diff --git a/.github/workflows/docs-integration-tests.yml b/.github/workflows/docs-integration-tests.yml index ad4457a20..2f1dba285 100644 --- a/.github/workflows/docs-integration-tests.yml +++ b/.github/workflows/docs-integration-tests.yml @@ -63,6 +63,9 @@ jobs: GOOGLE_AUTH_URI: ${{ secrets.INTEG_GOOGLE_AUTH_URI }} GOOGLE_TOKEN_URI: ${{ secrets.INTEG_GOOGLE_TOKEN_URI }} GOOGLE_AUTH_PROVIDER_X509_CERT_URL: ${{ secrets.INTEG_GOOGLE_AUTH_PROVIDER_X509_CERT_URL }} + GRIPTAPE_CLOUD_API_KEY: ${{ secrets.INTEG_GRIPTAPE_CLOUD_API_KEY }} + GRIPTAPE_CLOUD_STRUCTURE_ID: ${{ secrets.INTEG_GRIPTAPE_CLOUD_STRUCTURE_ID }} + GRIPTAPE_CLOUD_BASE_URL: ${{ secrets.INTEG_GRIPTAPE_CLOUD_BASE_URL }} OPENWEATHER_API_KEY: ${{ secrets.INTEG_OPENWEATHER_API_KEY }} ANTHROPIC_API_KEY: ${{ secrets.INTEG_ANTHROPIC_API_KEY }} SAGEMAKER_LLAMA_ENDPOINT_NAME: ${{ secrets.INTEG_LLAMA_ENDPOINT_NAME }} @@ -92,6 +95,11 @@ jobs: POSTGRES_HOST: ${{ secrets.INTEG_POSTGRES_HOST }} POSTGRES_PORT: ${{ secrets.INTEG_POSTGRES_PORT }} VOYAGE_API_KEY: ${{ secrets.INTEG_VOYAGE_API_KEY }} + WEBHOOK_URL: ${{ secrets.INTEG_WEBHOOK_URL }} + AMAZON_SQS_QUEUE_URL: ${{ secrets.INTEG_AMAZON_SQS_QUEUE_URL }} + GT_CLOUD_STRUCTURE_RUN_ID: ${{ secrets.INTEG_GT_CLOUD_STRUCTURE_RUN_ID }} + AWS_IOT_CORE_ENDPOINT: ${{ secrets.INTEG_AWS_IOT_CORE_ENDPOINT }} + AWS_IOT_CORE_TOPIC: ${{ secrets.INTEG_AWS_IOT_CORE_TOPIC }} services: postgres: image: ankane/pgvector:v0.5.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d16348bc..c2b9403b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Changed +- Default behavior of OpenAiStructureConfig to utilize `gpt-4o` for prompt_driver. + +## [0.25.1] - 2024-05-09 +### Added +- Optional event batching on Event Listener Drivers. +- `id` field to all events. + +### Changed +- Default behavior of Event Listener Drivers to batch events. + +## [0.25.0] - 2024-05-06 + ### Added - `list_files_from_disk` activity to `FileManager` Tool. - Support for Drivers in `EventListener`. @@ -13,10 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `AwsIotCoreEventListenerDriver` for sending events to a topic on AWS IoT Core. - `GriptapeCloudEventListenerDriver` for sending events to Griptape Cloud. - `WebhookEventListenerDriver` for sending events to a webhook. -- `LocalEventListenerDriver` for sending events to a callback function. - `BaseFileManagerDriver` to abstract file management operations. - `LocalFileManagerDriver` for managing files on the local file system. -- Added optional `BaseLoader.encoding` field. +- Optional `BaseLoader.encoding` field. - `BlobLoader` for loading arbitrary binary data as a `BlobArtifact`. - `model` field to `StartPromptEvent` and `FinishPromptEvent`. - `input_task_input` and `input_task_output` fields to `StartStructureRunEvent`. @@ -24,6 +36,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `AmazonS3FileManagerDriver` for managing files on Amazon S3. - `MediaArtifact` as a base class for `ImageArtifact` and future media Artifacts. - Optional `exception` field to `ErrorArtifact`. +- `StructureRunClient` for running other Structures via a Tool. +- `StructureRunTask` for running Structures as a Task from within another Structure. +- `GriptapeCloudStructureRunDriver` for running Structures in Griptape Cloud. +- `LocalStructureRunDriver` for running Structures in the same run-time environment as the code that is running the Structure. ### Changed - **BREAKING**: Secret fields (ex: api_key) removed from serialized Drivers. @@ -32,11 +48,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `PdfLoader` no longer accepts `str` file content, `Path` file paths or `IO` objects as sources. Instead, it will only accept the content of the PDF file as a `bytes` object. - **BREAKING**: `TextLoader` no longer accepts `Path` file paths as a source. It will now accept the content of the text file as a `str` or `bytes` object. - **BREAKING**: `FileManager.default_loader` is now `None` by default. -- **BREAKING**: Replaced `EventListener.handler` with `EventListener.driver` and `LocalEventListenerDriver`. -- Improved RAG performance in `VectorQueryEngine`. +- **BREAKING** Bumped `pinecone` from `^2` to `^3`. - **BREAKING**: Removed `workdir`, `loaders`, `default_loader`, and `save_file_encoding` fields from `FileManager` and added `file_manager_driver`. -- **BREADKING**: Removed `mime_type` field from `ImageArtifact`. `mime_type` is now a property constructed using the Artifact type and `format` field. +- **BREAKING**: Removed `mime_type` field from `ImageArtifact`. `mime_type` is now a property constructed using the Artifact type and `format` field. +- Improved RAG performance in `VectorQueryEngine`. - Moved [Griptape Docs](https://github.com/griptape-ai/griptape-docs) to this repository. +- Updated `EventListener.handler`'s behavior so that the return value will be passed to the `EventListenerDriver.try_publish_event_payload`'s `event_payload` parameter. ### Fixed - Type hint for parameter `azure_ad_token_provider` on Azure OpenAI drivers to `Optional[Callable[[], str]]`. @@ -122,13 +139,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `ImageQueryTask` and `ImageQueryEngine`. ### Fixed -- `BedrockStableDiffusionImageGenerationModelDriver` request parameters for SDXLv1. +- `BedrockStableDiffusionImageGenerationModelDriver` request parameters for SDXLv1 (`stability.stable-diffusion-xl-v1`). - `BedrockStableDiffusionImageGenerationModelDriver` correctly handles the CONTENT_FILTERED response case. ### Changed - **BREAKING**: Make `index_name` on `MongoDbAtlasVectorStoreDriver` a required field. - **BREAKING**: Remove `create_index()` from `MarqoVectorStoreDriver`, `OpenSearchVectorStoreDriver`, `PineconeVectorStoreDriver`, `RedisVectorStoreDriver`. - **BREAKING**: `ImageLoader().load()` now accepts image bytes instead of a file path. +- **BREAKING**: Request parameters for `BedrockStableDiffusionImageGenerationModelDriver` have been updated for `stability.stable-diffusion-xl-v1`. Use this over the now deprecated `stability.stable-diffusion-xl-v0`. - Deprecated `Structure.prompt_driver` in favor of `Structure.config.global_drivers.prompt_driver`. - Deprecated `Structure.embedding_driver` in favor of `Structure.config.global_drivers.embedding_driver`. - Deprecated `Structure.stream` in favor of `Structure.config.global_drivers.prompt_driver.stream`. @@ -147,7 +165,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.22.2] - 2024-01-18 ### Fixed -- `ToolkitTask`'s user subtask prompt occassionally causing a loop with Chain of Thought. +- `ToolkitTask`'s user subtask prompt occasionally causing a loop with Chain of Thought. ### Security - Updated stale dependencies [CVE-2023-50447, CVE-2024-22195, and CVE-2023-36464] diff --git a/README.md b/README.md index 847c9d9c9..53c941a2e 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,8 @@ Tools provide capabilities for LLMs to interact with data and services. Griptape Drivers facilitate interactions with external resources and services: -- 🔢 **Prompt and Embedding Drivers** generate vector embeddings from textual inputs. +- 🗣️ **Prompt Drivers** manage textual interactions with LLMs. +- 🔢 **Embedding Drivers** generate vector embeddings from textual inputs. - 💾 **Vector Store Drivers** manage the storage and retrieval of embeddings. - 🎨 **Image Generation Drivers** create images from text descriptions. - 🔎 **Image Query Drivers** query images from text queries. diff --git a/docs/examples/multi-agent-workflow.md b/docs/examples/multi-agent-workflow.md new file mode 100644 index 000000000..0e85b9bde --- /dev/null +++ b/docs/examples/multi-agent-workflow.md @@ -0,0 +1,191 @@ +In this example we implement a multi-agent Workflow. We have a single "Researcher" Agent that conducts research on a topic, and then fans out to multiple "Writer" Agents to write blog posts based on the research. + +By splitting up our workloads across multiple Structures, we can parallelize the work and leverage the strengths of each Agent. The Researcher can focus on gathering data and insights, while the Writers can focus on crafting engaging narratives. +Additionally, this architecture opens us up to using services such as [Griptape Cloud](https://www.griptape.ai/cloud) to have each Agent run on a separate machine, allowing us to scale our Workflow as needed 🤯. + + +```python +import os + +from griptape.drivers import WebhookEventListenerDriver, LocalStructureRunDriver +from griptape.events import EventListener, FinishStructureRunEvent +from griptape.rules import Rule, Ruleset +from griptape.structures import Agent, Workflow +from griptape.tasks import PromptTask, StructureRunTask +from griptape.tools import ( + TaskMemoryClient, + WebScraper, + WebSearch, +) + +WRITERS = [ + { + "role": "Travel Adventure Blogger", + "goal": "Inspire wanderlust with stories of hidden gems and exotic locales", + "backstory": "With a passport full of stamps, you bring distant cultures and breathtaking scenes to life through vivid storytelling and personal anecdotes.", + }, + { + "role": "Lifestyle Freelance Writer", + "goal": "Share practical advice on living a balanced and stylish life", + "backstory": "From the latest trends in home decor to tips for wellness, your articles help readers create a life that feels both aspirational and attainable.", + }, +] + + +def build_researcher(): + """Builds a Researcher Structure.""" + researcher = Agent( + id="researcher", + tools=[ + WebSearch( + google_api_key=os.environ["GOOGLE_API_KEY"], + google_api_search_id=os.environ["GOOGLE_API_SEARCH_ID"], + off_prompt=False, + ), + WebScraper( + off_prompt=True, + ), + TaskMemoryClient(off_prompt=False), + ], + rulesets=[ + Ruleset( + name="Position", + rules=[ + Rule( + value="Lead Research Analyst", + ) + ], + ), + Ruleset( + name="Objective", + rules=[ + Rule( + value="Discover innovative advancements in artificial intelligence and data analytics", + ) + ], + ), + Ruleset( + name="Background", + rules=[ + Rule( + value="""You are part of a prominent technology research institute. + Your speciality is spotting new trends. + You excel at analyzing intricate data and delivering practical insights.""" + ) + ], + ), + Ruleset( + name="Desired Outcome", + rules=[ + Rule( + value="Comprehensive analysis report in list format", + ) + ], + ), + ], + ) + + return researcher + + +def build_writer(role: str, goal: str, backstory: str): + """Builds a Writer Structure. + + Args: + role: The role of the writer. + goal: The goal of the writer. + backstory: The backstory of the writer. + """ + writer = Agent( + id=role.lower().replace(" ", "_"), + event_listeners=[ + EventListener( + event_types=[FinishStructureRunEvent], + driver=WebhookEventListenerDriver( + webhook_url=os.environ["WEBHOOK_URL"], + ), + ) + ], + rulesets=[ + Ruleset( + name="Position", + rules=[ + Rule( + value=role, + ) + ], + ), + Ruleset( + name="Objective", + rules=[ + Rule( + value=goal, + ) + ], + ), + Ruleset( + name="Backstory", + rules=[Rule(value=backstory)], + ), + Ruleset( + name="Desired Outcome", + rules=[ + Rule( + value="Full blog post of at least 4 paragraphs", + ) + ], + ), + ], + ) + + return writer + + +if __name__ == "__main__": + # Build the team + team = Workflow() + research_task = team.add_task( + StructureRunTask( + ( + """Perform a detailed examination of the newest developments in AI as of 2024. + Pinpoint major trends, breakthroughs, and their implications for various industries.""", + ), + id="research", + driver=LocalStructureRunDriver( + structure_factory_fn=build_researcher, + ), + ), + ) + end_task = team.add_task( + PromptTask( + 'State "All Done!"', + ) + ) + team.insert_tasks( + research_task, + [ + StructureRunTask( + ( + """Using insights provided, develop an engaging blog + post that highlights the most significant AI advancements. + Your post should be informative yet accessible, catering to a tech-savvy audience. + Make it sound cool, avoid complex words so it doesn't sound like AI. + + Insights: + {{ parent_outputs["research"] }}""", + ), + driver=LocalStructureRunDriver( + structure_factory_fn=lambda: build_writer( + role=writer["role"], + goal=writer["goal"], + backstory=writer["backstory"], + ) + ), + ) + for writer in WRITERS + ], + end_task, + ) + + team.run() +``` diff --git a/docs/griptape-cloud/api/api-reference.md b/docs/griptape-cloud/api/api-reference.md new file mode 100644 index 000000000..1b49f33d8 --- /dev/null +++ b/docs/griptape-cloud/api/api-reference.md @@ -0,0 +1 @@ +# Content overridden by Swagger Plugin diff --git a/docs/griptape-cloud/index.md b/docs/griptape-cloud/index.md new file mode 100644 index 000000000..6a841dccc --- /dev/null +++ b/docs/griptape-cloud/index.md @@ -0,0 +1,3 @@ +# Griptape Cloud + +Griptape Cloud is a managed platform for running AI-powered agents, pipelines, and workflows. \ No newline at end of file diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index 8aac96d6a..d767d0d8f 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -13,17 +13,24 @@ Inherits from the [TextLoader](../../reference/griptape/loaders/text_loader.md) ```python from griptape.loaders import PdfLoader +from griptape.utils import load_files, load_file import urllib.request urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", "attention.pdf") +# Load a single PDF file with open("attention.pdf", "rb") as f: PdfLoader().load(f.read()) +# You can also use the load_file utility function +PdfLoader().load(load_file("attention.pdf")) urllib.request.urlretrieve("https://arxiv.org/pdf/1706.03762.pdf", "CoT.pdf") +# Load multiple PDF files with open("attention.pdf", "rb") as attention, open("CoT.pdf", "rb") as cot: PdfLoader().load_collection([attention.read(), cot.read()]) +# You can also use the load_files utility function +PdfLoader().load_collection(list(load_files(["attention.pdf", "CoT.pdf"]).values())) ``` ## Sql Loader @@ -53,12 +60,19 @@ Can be used to load CSV files into [CsvRowArtifact](../../reference/griptape/art ```python from griptape.loaders import CsvLoader +from griptape.utils import load_file, load_files +# Load a single CSV file with open("tests/resources/cities.csv", "r") as f: CsvLoader().load(f.read()) +# You can also use the load_file utility function +CsvLoader().load(load_file("tests/resources/cities.csv")) +# Load multiple CSV files with open("tests/resources/cities.csv", "r") as cities, open("tests/resources/addresses.csv", "r") as addresses: CsvLoader().load_collection([cities.read(), addresses.read()]) +# You can also use the load_files utility function +CsvLoader().load_collection(list(load_files(["tests/resources/cities.csv", "tests/resources/addresses.csv"]).values())) ``` @@ -140,19 +154,32 @@ The Image Loader is used to load an image as an [ImageArtifact](./artifacts.md#i ```python from griptape.loaders import ImageLoader +from griptape.utils import load_file +# Load an image from disk with open("tests/resources/mountain.png", "rb") as f: disk_image_artifact = ImageLoader().load(f.read()) +# You can also use the load_file utility function +ImageLoader().load(load_file("tests/resources/mountain.png")) ``` By default, the Image Loader will load images in their native format, but not all models work on all formats. To normalize the format of Artifacts returned by the Loader, set the `format` field. ```python from griptape.loaders import ImageLoader +from griptape.utils import load_files, load_file -# Image data in artifact will be in BMP format. +# Load a single image in BMP format with open("tests/resources/mountain.png", "rb") as f: image_artifact_jpeg = ImageLoader(format="bmp").load(f.read()) +# You can also use the load_file utility function +ImageLoader(format="bmp").load(load_file("tests/resources/mountain.png")) + +# Load multiple images in BMP format +with open("tests/resources/mountain.png", "rb") as mountain, open("tests/resources/cow.png", "rb") as cow: + ImageLoader().load_collection([mountain.read(), cow.read()]) +# You can also use the load_files utility function +ImageLoader().load_collection(list(load_files(["tests/resources/mountain.png", "tests/resources/cow.png"]).values())) ``` diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md new file mode 100644 index 000000000..da9a6c05d --- /dev/null +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -0,0 +1,217 @@ +## Overview + +Event Listener Drivers are used to send Griptape [Events](../misc/events.md) to external services. + +You can instantiate Drivers and pass them to Event Listeners in your Structure: + +```python +import os + +from griptape.drivers import AmazonSqsEventListenerDriver +from griptape.events import ( + EventListener, +) +from griptape.rules import Rule +from griptape.structures import Agent + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a block of text, and your task is to extract a list of keywords from it." + ) + ], + event_listeners=[ + EventListener( + handler=lambda event: { # You can optionally use the handler to transform the event payload before sending it to the Driver + "event": event.to_dict(), + }, + driver=AmazonSqsEventListenerDriver( + queue_url=os.environ["AMAZON_SQS_QUEUE_URL"], + ), + ), + ], +) + +agent.run( + """Black-on-black ware is a 20th- and 21st-century pottery tradition developed by the Puebloan Native American ceramic artists in Northern New Mexico. + Traditional reduction-fired blackware has been made for centuries by pueblo artists. + Black-on-black ware of the past century is produced with a smooth surface, with the designs applied through selective burnishing or the application of refractory slip. + Another style involves carving or incising designs and selectively polishing the raised areas. + For generations several families from Kha'po Owingeh and P'ohwhóge Owingeh pueblos have been making black-on-black ware with the techniques passed down from matriarch potters. Artists from other pueblos have also produced black-on-black ware. + Several contemporary artists have created works honoring the pottery of their ancestors.""" +) +``` + +Or use them independently: + +```python +import os +from griptape.drivers import GriptapeCloudEventListenerDriver +from griptape.events import FinishStructureRunEvent +from griptape.artifacts import TextArtifact + +event_driver = GriptapeCloudEventListenerDriver( + api_key=os.environ["GRIPTAPE_CLOUD_API_KEY"] +) + +done_event = FinishStructureRunEvent( + output_task_input=TextArtifact("Just started!"), + output_task_output=TextArtifact("All done!"), +) + +event_driver.publish_event(done_event) +``` + +## Event Listener Drivers + +Griptape offers the following Event Listener Drivers for forwarding Griptape Events. + +### Amazon SQS Event Listener Driver + +!!! info + This driver requires the `drivers-event-listener-amazon-sqs` [extra](../index.md#extras). + +The [AmazonSqsEventListenerDriver](../../reference/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.md) sends Events to an [Amazon SQS](https://aws.amazon.com/sqs/) queue. + +```python +import os + +from griptape.drivers import AmazonSqsEventListenerDriver +from griptape.events import ( + EventListener, +) +from griptape.rules import Rule +from griptape.structures import Agent + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a block of text, and your task is to extract a list of keywords from it." + ) + ], + event_listeners=[ + EventListener( + driver=AmazonSqsEventListenerDriver( + queue_url=os.environ["AMAZON_SQS_QUEUE_URL"], + ), + ), + ], +) + +agent.run( + """Black-on-black ware is a 20th- and 21st-century pottery tradition developed by the Puebloan Native American ceramic artists in Northern New Mexico. + Traditional reduction-fired blackware has been made for centuries by pueblo artists. + Black-on-black ware of the past century is produced with a smooth surface, with the designs applied through selective burnishing or the application of refractory slip. + Another style involves carving or incising designs and selectively polishing the raised areas. + For generations several families from Kha'po Owingeh and P'ohwhóge Owingeh pueblos have been making black-on-black ware with the techniques passed down from matriarch potters. Artists from other pueblos have also produced black-on-black ware. + Several contemporary artists have created works honoring the pottery of their ancestors.""" +) +``` + +### AWS IoT Event Listener Driver + +!!! info + This driver requires the `drivers-event-listener-amazon-iot` [extra](../index.md#extras). + +The [AwsIotCoreEventListenerDriver](../../reference/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.md) sends Events to the [AWS IoT Message Broker](https://aws.amazon.com/iot-core/). + +```python +import os + +from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriver +from griptape.events import ( + EventListener, + FinishStructureRunEvent, +) +from griptape.rules import Rule +from griptape.structures import Agent + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a text, and your task is to extract the airport codes from it." + ) + ], + config=StructureConfig( + global_drivers=StructureGlobalDriversConfig( + prompt_driver=OpenAiChatPromptDriver( + model="gpt-3.5-turbo", temperature=0.7 + ), + ) + ), + event_listeners=[ + EventListener( + event_types=[FinishStructureRunEvent], + driver=AwsIotCoreEventListenerDriver( + topic=os.environ["AWS_IOT_CORE_TOPIC"], + iot_endpoint=os.environ["AWS_IOT_CORE_ENDPOINT"], + ), + ), + ], +) + +agent.run("I want to fly from Orlando to Boston") +``` + +### Griptape Cloud Event Listener Driver + +The [GriptapeCloudEventListenerDriver](../../reference/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.md) sends Events to [Griptape Cloud](https://www.griptape.ai/cloud). + +!!! note + This Driver is required when using the Griptape Cloud Managed Structures feature. For local development, you can use the [Skatepark Emulator](https://github.com/griptape-ai/griptape-cli?tab=readme-ov-file#skatepark-emulator). + +```python +import os + +from griptape.drivers import GriptapeCloudEventListenerDriver +from griptape.events import ( + EventListener, + FinishStructureRunEvent, +) +from griptape.structures import Agent + +agent = Agent( + event_listeners=[ + EventListener( + event_types=[FinishStructureRunEvent], + driver=GriptapeCloudEventListenerDriver( + api_key=os.environ["GRIPTAPE_CLOUD_API_KEY"], + ), + ), + ], +) + +agent.run( + "Create a list of 8 questions for an interview with a science fiction author." +) +``` + +### Webhook Event Listener Driver + +The [WebhookEventListenerDriver](../../reference/griptape/drivers/event_listener/webhook_event_listener_driver.md) sends Events to any [Webhook](https://en.wikipedia.org/wiki/Webhook) URL. + +```python +import os + +from griptape.drivers import WebhookEventListenerDriver +from griptape.events import ( + EventListener, + FinishStructureRunEvent, +) +from griptape.structures import Agent + +agent = Agent( + event_listeners=[ + EventListener( + event_types=[FinishStructureRunEvent], + driver=WebhookEventListenerDriver( + webhook_url=os.environ["WEBHOOK_URL"], + ), + ), + ], +) + +agent.run("Analyze the pros and cons of remote work vs. office work") +``` + diff --git a/docs/griptape-framework/drivers/structure-run-drivers.md b/docs/griptape-framework/drivers/structure-run-drivers.md new file mode 100644 index 000000000..c2b94190c --- /dev/null +++ b/docs/griptape-framework/drivers/structure-run-drivers.md @@ -0,0 +1,103 @@ +## Overview +Structure Run Drivers can be used to run Griptape Structures in a variety of runtime environments. +When combined with the [Structure Run Task](../../griptape-framework/structures/tasks.md#structure-run-task) or [Structure Run Client](../../griptape-tools/official-tools/structure-run-client.md) you can create complex, multi-agent pipelines that span multiple runtime environments. + +## Local Structure Run Driver + +The [LocalStructureRunDriver](../../reference/griptape/drivers/structure_run/local_structure_run_driver.md) is used to run Griptape Structures in the same runtime environment as the code that is running the Structure. + +```python +from griptape.drivers import LocalStructureRunDriver +from griptape.rules import Rule +from griptape.structures import Agent, Pipeline +from griptape.tasks import StructureRunTask + +def build_joke_teller(): + joke_teller = Agent( + rules=[ + Rule( + value="You are very funny.", + ) + ], + ) + + return joke_teller + +def build_joke_rewriter(): + joke_rewriter = Agent( + rules=[ + Rule( + value="You are the editor of a joke book. But you only speak in riddles", + ) + ], + ) + + return joke_rewriter + +joke_coordinator = Pipeline( + tasks=[ + StructureRunTask( + driver=LocalStructureRunDriver( + structure_factory_fn=build_joke_teller, + ), + ), + StructureRunTask( + ("Rewrite this joke: {{ parent_output }}",), + driver=LocalStructureRunDriver( + structure_factory_fn=build_joke_rewriter, + ), + ), + ] +) + +joke_coordinator.run("Tell me a joke") +``` + +## Griptape Cloud Structure Run Driver + +The [GriptapeCloudStructureRunDriver](../../reference/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.md) is used to run Griptape Structures in the Griptape Cloud. + + +```python +import os + +from griptape.drivers import GriptapeCloudStructureRunDriver, LocalStructureRunDriver +from griptape.structures import Pipeline, Agent +from griptape.rules import Rule +from griptape.tasks import StructureRunTask + +base_url = os.environ["GRIPTAPE_CLOUD_BASE_URL"] +api_key = os.environ["GRIPTAPE_CLOUD_API_KEY"] +structure_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"] + + +pipeline = Pipeline( + tasks=[ + StructureRunTask( + ("Think of a question related to Retrieval Augmented Generation.",), + driver=LocalStructureRunDriver( + structure_factory_fn=lambda: Agent( + rules=[ + Rule( + value="You are an expert in Retrieval Augmented Generation.", + ), + Rule( + value="Only output your answer, no other information.", + ), + ] + ) + ), + ), + StructureRunTask( + ("{{ parent_output }}",), + driver=GriptapeCloudStructureRunDriver( + base_url=base_url, + api_key=api_key, + structure_id=structure_id, + ), + ), + ] +) + +pipeline.run() +``` diff --git a/docs/griptape-framework/engines/image-generation-engines.md b/docs/griptape-framework/engines/image-generation-engines.md index e5e7c77e1..0c3997fa9 100644 --- a/docs/griptape-framework/engines/image-generation-engines.md +++ b/docs/griptape-framework/engines/image-generation-engines.md @@ -21,7 +21,7 @@ from griptape.rules import Ruleset, Rule # Create a driver configured to use Stable Diffusion via Bedrock. driver = AmazonBedrockImageGenerationDriver( image_generation_model_driver=BedrockStableDiffusionImageGenerationModelDriver(), - model="stability.stable-diffusion-xl-v0", + model="stability.stable-diffusion-xl-v1", ) # Create an engine configured to use the driver. @@ -52,7 +52,7 @@ from griptape.drivers import AmazonBedrockImageGenerationDriver, \ # Create a driver configured to use Stable Diffusion via Bedrock. driver = AmazonBedrockImageGenerationDriver( image_generation_model_driver=BedrockStableDiffusionImageGenerationModelDriver(), - model="stability.stable-diffusion-xl-v0", + model="stability.stable-diffusion-xl-v1", ) # Create an engine configured to use the driver. @@ -78,7 +78,7 @@ from griptape.loaders import ImageLoader # Create a driver configured to use Stable Diffusion via Bedrock. driver = AmazonBedrockImageGenerationDriver( image_generation_model_driver=BedrockStableDiffusionImageGenerationModelDriver(), - model="stability.stable-diffusion-xl-v0", + model="stability.stable-diffusion-xl-v1", ) # Create an engine configured to use the driver. @@ -109,7 +109,7 @@ from griptape.loaders import ImageLoader # Create a driver configured to use Stable Diffusion via Bedrock. driver = AmazonBedrockImageGenerationDriver( image_generation_model_driver=BedrockStableDiffusionImageGenerationModelDriver(), - model="stability.stable-diffusion-xl-v0", + model="stability.stable-diffusion-xl-v1", ) # Create an engine configured to use the driver. @@ -143,7 +143,7 @@ from griptape.loaders import ImageLoader # Create a driver configured to use Stable Diffusion via Bedrock. driver = AmazonBedrockImageGenerationDriver( image_generation_model_driver=BedrockStableDiffusionImageGenerationModelDriver(), - model="stability.stable-diffusion-xl-v0", + model="stability.stable-diffusion-xl-v1", ) # Create an engine configured to use the driver. diff --git a/docs/griptape-framework/index.md b/docs/griptape-framework/index.md index 7c33e0588..decf00da5 100644 --- a/docs/griptape-framework/index.md +++ b/docs/griptape-framework/index.md @@ -22,7 +22,7 @@ By default, Griptape uses [OpenAI Completions API](https://platform.openai.com/d Install **griptape**: ``` -pip install griptape[all] -U +pip install "griptape[all]" -U ``` ### Using Poetry @@ -36,7 +36,7 @@ poetry new griptape-quickstart Change your working directory to the new `griptape-quickstart` directory created by Poetry and add the the `griptape` dependency. ``` -poetry add griptape[all] +poetry add "griptape[all]" ``` ### Extras @@ -56,7 +56,7 @@ poetry add griptape To install specific extras (e.g., drivers for [AnthropicPromptDriver](./drivers/prompt-drivers.md#anthropic) and [PineconeVectorStoreDriver](./drivers/vector-store-drivers.md#pinecone)): ``` -poetry add griptape[drivers-prompt-anthropic,drivers-vector-pinecone] +poetry add "griptape[drivers-prompt-anthropic,drivers-vector-pinecone]" ``` For a comprehensive list of extras, please refer to the `[tool.poetry.extras]` section of Griptape's [pyproject.toml](https://github.com/griptape-ai/griptape/blob/main/pyproject.toml). diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index bd43f2404..f45f77199 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -1,6 +1,7 @@ ## Overview You can use [EventListener](../../reference/griptape/events/event_listener.md)s to listen for events during a Structure's execution. +See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -18,7 +19,6 @@ from griptape.events import ( FinishPromptEvent, EventListener, ) -from griptape.drivers import LocalEventListenerDriver def handler(event: BaseEvent): @@ -28,6 +28,7 @@ def handler(event: BaseEvent): agent = Agent( event_listeners=[ EventListener( + handler, event_types=[ StartTaskEvent, FinishTaskEvent, @@ -36,7 +37,6 @@ agent = Agent( StartPromptEvent, FinishPromptEvent, ], - driver=LocalEventListenerDriver(handler=handler), ) ] ) @@ -65,7 +65,6 @@ Or listen to all events: ```python from griptape.structures import Agent from griptape.events import BaseEvent, EventListener -from griptape.drivers import LocalEventListenerDriver def handler1(event: BaseEvent): @@ -78,8 +77,8 @@ def handler2(event: BaseEvent): agent = Agent( event_listeners=[ - EventListener(driver=LocalEventListenerDriver(handler=handler1)), - EventListener(driver=LocalEventListenerDriver(handler=handler1)), + EventListener(handler1), + EventListener(handler2), ] ) @@ -131,13 +130,12 @@ from griptape.events import CompletionChunkEvent, EventListener from griptape.tasks import ToolkitTask from griptape.structures import Pipeline from griptape.tools import WebScraper, TaskMemoryClient -from griptape.drivers import LocalEventListenerDriver pipeline = Pipeline( event_listeners=[ EventListener( - driver=LocalEventListenerDriver(handler=lambda e: print(e.token, end="", flush=True)), + lambda e: print(e.token, end="", flush=True), event_types=[CompletionChunkEvent], ) ] @@ -180,7 +178,6 @@ To count tokens, you can use Event Listeners and the [TokenCounter](../../refere from griptape import utils from griptape.events import BaseEvent, StartPromptEvent, FinishPromptEvent, EventListener from griptape.structures import Agent -from griptape.drivers import LocalEventListenerDriver token_counter = utils.TokenCounter() @@ -194,7 +191,7 @@ def count_tokens(e: BaseEvent): agent = Agent( event_listeners=[ EventListener( - driver=LocalEventListenerDriver(handler=lambda e: count_tokens(e)), + handler=lambda e: count_tokens(e), event_types=[StartPromptEvent, FinishPromptEvent], ) ] @@ -238,7 +235,6 @@ You can use the [StartPromptEvent](../../reference/griptape/events/start_prompt_ ```python from griptape.structures import Agent from griptape.events import BaseEvent, StartPromptEvent, EventListener -from griptape.drivers import LocalEventListenerDriver def handler(event: BaseEvent): @@ -251,7 +247,7 @@ def handler(event: BaseEvent): agent = Agent( - event_listeners=[EventListener(driver=LocalEventListenerDriver(handler=handler), event_types=[StartPromptEvent])] + event_listeners=[EventListener(handler=handler, event_types=[StartPromptEvent])] ) agent.run("Write me a poem.") diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index a7aae8734..c7f9e78cd 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -667,3 +667,143 @@ pipeline.add_task( pipeline.run("Describe the weather in the image") ``` + +## Structure Run Task +The [Structure Run Task](../../reference/griptape/tasks/structure_run_task.md) executes another Structure with a given input. +This Task is useful for orchestrating multiple specialized Structures in a single run. Note that the input to the Task is a tuple of arguments that will be passed to the Structure. + +```python +import os + +from griptape.rules import Rule, Ruleset +from griptape.structures import Agent, Pipeline +from griptape.tasks import StructureRunTask +from griptape.drivers import LocalStructureRunDriver +from griptape.tools import ( + TaskMemoryClient, + WebScraper, + WebSearch, +) + + +def build_researcher(): + researcher = Agent( + tools=[ + WebSearch( + google_api_key=os.environ["GOOGLE_API_KEY"], + google_api_search_id=os.environ["GOOGLE_API_SEARCH_ID"], + off_prompt=False, + ), + WebScraper( + off_prompt=True, + ), + TaskMemoryClient(off_prompt=False), + ], + rulesets=[ + Ruleset( + name="Position", + rules=[ + Rule( + value="Senior Research Analyst", + ) + ], + ), + Ruleset( + name="Objective", + rules=[ + Rule( + value="Uncover cutting-edge developments in AI and data science", + ) + ], + ), + Ruleset( + name="Background", + rules=[ + Rule( + value="""You work at a leading tech think tank., + Your expertise lies in identifying emerging trends. + You have a knack for dissecting complex data and presenting actionable insights.""" + ) + ], + ), + Ruleset( + name="Desired Outcome", + rules=[ + Rule( + value="Full analysis report in bullet points", + ) + ], + ), + ], + ) + + return researcher + + +def build_writer(): + writer = Agent( + input_template="Instructions: {{args[0]}}\nContext: {{args[1]}}", + rulesets=[ + Ruleset( + name="Position", + rules=[ + Rule( + value="Tech Content Strategist", + ) + ], + ), + Ruleset( + name="Objective", + rules=[ + Rule( + value="Craft compelling content on tech advancements", + ) + ], + ), + Ruleset( + name="Backstory", + rules=[ + Rule( + value="""You are a renowned Content Strategist, known for your insightful and engaging articles. + You transform complex concepts into compelling narratives.""" + ) + ], + ), + Ruleset( + name="Desired Outcome", + rules=[ + Rule( + value="Full blog post of at least 4 paragraphs", + ) + ], + ), + ], + ) + + return writer + + +team = Pipeline( + tasks=[ + StructureRunTask( + ( + """Perform a detailed examination of the newest developments in AI as of 2024. + Pinpoint major trends, breakthroughs, and their implications for various industries.""", + ), + driver=LocalStructureRunDriver(structure_factory_fn=build_researcher), + ), + StructureRunTask( + ( + """Utilize the gathered insights to craft a captivating blog + article showcasing the key AI innovations. + Ensure the content is engaging yet straightforward, appealing to a tech-aware readership. + Keep the tone appealing and use simple language to make it less technical.""", + "{{parent_output}}", + ), + driver=LocalStructureRunDriver(structure_factory_fn=build_writer), + ), + ], +) + +team.run() +``` diff --git a/docs/griptape-tools/official-tools/structure-run-client.md b/docs/griptape-tools/official-tools/structure-run-client.md new file mode 100644 index 000000000..863e02727 --- /dev/null +++ b/docs/griptape-tools/official-tools/structure-run-client.md @@ -0,0 +1,64 @@ +# StructureRunClient + +The StructureRunClient Tool provides a way to run Structures via a Tool. +It requires you to provide a [Structure Run Driver](../../griptape-framework/drivers/structure-run-drivers.md) to run the Structure in the desired environment. + +```python +import os + +from griptape.drivers import GriptapeCloudStructureRunDriver +from griptape.structures import Agent +from griptape.tools import StructureRunClient + +base_url = os.environ["GRIPTAPE_CLOUD_BASE_URL"] +api_key = os.environ["GRIPTAPE_CLOUD_API_KEY"] +structure_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"] + +structure_run_tool = StructureRunClient( + description="RAG Expert Agent - Structure to invoke with natural language queries about the topic of Retrieval Augmented Generation", + driver=GriptapeCloudStructureRunDriver( + base_url=base_url, + api_key=api_key, + structure_id=structure_id, + ), + off_prompt=False, +) + +# Set up an agent using the StructureRunClient tool +agent = Agent(tools=[structure_run_tool]) + +# Task: Ask the Griptape Cloud Hosted Structure about modular RAG +agent.run("what is modular RAG?") +``` +``` +[05/02/24 13:50:03] INFO ToolkitTask 4e9458375bda4fbcadb77a94624ed64c + Input: what is modular RAG? +[05/02/24 13:50:10] INFO Subtask 5ef2d72028fc495aa7faf6f46825b004 + Thought: To answer this question, I need to run a search for the term "modular RAG". I will use the StructureRunClient action to execute a + search structure. + Actions: [ + { + "name": "StructureRunClient", + "path": "run_structure", + "input": { + "values": { + "args": "modular RAG" + } + }, + "tag": "search_modular_RAG" + } + ] +[05/02/24 13:50:36] INFO Subtask 5ef2d72028fc495aa7faf6f46825b004 + Response: {'id': '87fa21aded76416e988f8bf39c19760b', 'name': '87fa21aded76416e988f8bf39c19760b', 'type': 'TextArtifact', 'value': 'Modular + Retrieval-Augmented Generation (RAG) is an advanced approach that goes beyond the traditional RAG paradigms, offering enhanced adaptability + and versatility. It involves incorporating diverse strategies to improve its components by adding specialized modules for retrieval and + processing capabilities. The Modular RAG framework allows for module substitution or reconfiguration to address specific challenges, expanding + flexibility by integrating new modules or adjusting interaction flow among existing ones. This approach supports both sequential processing + and integrated end-to-end training across its components, illustrating progression and refinement within the RAG family.'} +[05/02/24 13:50:44] INFO ToolkitTask 4e9458375bda4fbcadb77a94624ed64c + Output: Modular Retrieval-Augmented Generation (RAG) is an advanced approach that goes beyond the traditional RAG paradigms, offering enhanced + adaptability and versatility. It involves incorporating diverse strategies to improve its components by adding specialized modules for + retrieval and processing capabilities. The Modular RAG framework allows for module substitution or reconfiguration to address specific + challenges, expanding flexibility by integrating new modules or adjusting interaction flow among existing ones. This approach supports both + sequential processing and integrated end-to-end training across its components, illustrating progression and refinement within the RAG family. +``` diff --git a/docs/index.md b/docs/index.md index 727761f87..5d22224e7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,6 +6,10 @@ Welcome to Griptape Docs! This documentation is organized into the following sec Griptape Topic Guides discuss key topics at a high level and provide useful background information and explanation. +### Griptape Cloud + +[Griptape Cloud](griptape-cloud/api/api-reference.md) provides an overview of the APIs available in the managed cloud service. + ### Griptape Framework [Griptape Framework](griptape-framework/index.md) provides an overview of the key topics within Griptape, and how you can get started building agents. diff --git a/docs/plugins/swagger_ui_plugin.py b/docs/plugins/swagger_ui_plugin.py index 0243f6900..6d5fb52da 100644 --- a/docs/plugins/swagger_ui_plugin.py +++ b/docs/plugins/swagger_ui_plugin.py @@ -5,7 +5,7 @@ from markupsafe import Markup config_scheme = { - "spec_url": "https://cloud-preview.griptape.ai/public/openapi.yaml", + "spec_url": "https://griptape-cloud-assets.s3.amazonaws.com/Griptape.openapi.yaml", "template": "swagger.md.tmpl", "outfile": "griptape-cloud/api/api-reference.md", } diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_structure_config.py index 283fca2d1..64c32ecec 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_structure_config.py @@ -24,7 +24,7 @@ class OpenAiStructureConfig(BaseStructureConfig): global_drivers: StructureGlobalDriversConfig = field( default=Factory( lambda: StructureGlobalDriversConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"), image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview"), embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"), diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 462da47c4..f005e4670 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -91,7 +91,6 @@ from .event_listener.webhook_event_listener_driver import WebhookEventListenerDriver from .event_listener.aws_iot_core_event_listener_driver import AwsIotCoreEventListenerDriver from .event_listener.griptape_cloud_event_listener_driver import GriptapeCloudEventListenerDriver -from .event_listener.local_event_listener_driver import LocalEventListenerDriver from .file_manager.base_file_manager_driver import BaseFileManagerDriver from .file_manager.local_file_manager_driver import LocalFileManagerDriver @@ -101,6 +100,10 @@ from .text_to_speech.dummy_text_to_speech_driver import DummyTextToSpeechDriver from .text_to_speech.elevenlabs_text_to_speech_driver import ElevenLabsTextToSpeechDriver +from .structure_run.base_structure_run_driver import BaseStructureRunDriver +from .structure_run.griptape_cloud_structure_run_driver import GriptapeCloudStructureRunDriver +from .structure_run.local_structure_run_driver import LocalStructureRunDriver + __all__ = [ "BasePromptDriver", "OpenAiChatPromptDriver", @@ -181,11 +184,13 @@ "WebhookEventListenerDriver", "AwsIotCoreEventListenerDriver", "GriptapeCloudEventListenerDriver", - "LocalEventListenerDriver", "BaseFileManagerDriver", "LocalFileManagerDriver", "AmazonS3FileManagerDriver", "BaseTextToSpeechDriver", "DummyTextToSpeechDriver", "ElevenLabsTextToSpeechDriver", + "BaseStructureRunDriver", + "GriptapeCloudStructureRunDriver", + "LocalStructureRunDriver", ] diff --git a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py index 0db63726b..1c8132b67 100644 --- a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py +++ b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py @@ -1,12 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any import json +from typing import TYPE_CHECKING, Any from attr import Factory, define, field from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver -from griptape.events.base_event import BaseEvent from griptape.utils import import_optional_dependency if TYPE_CHECKING: @@ -19,5 +18,13 @@ class AmazonSqsEventListenerDriver(BaseEventListenerDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) sqs_client: Any = field(default=Factory(lambda self: self.session.client("sqs"), takes_self=True)) - def try_publish_event(self, event: BaseEvent) -> None: - self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps({"event": event.to_dict()})) + def try_publish_event_payload(self, event_payload: dict) -> None: + self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload)) + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + entries = [ + {"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)} + for event_payload in event_payload_batch + ] + + self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries) diff --git a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py index 876b790e8..c4fd72084 100644 --- a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py +++ b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py @@ -1,12 +1,11 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING, Any -import json from attr import Factory, define, field from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver -from griptape.events.base_event import BaseEvent from griptape.utils import import_optional_dependency if TYPE_CHECKING: @@ -20,5 +19,8 @@ class AwsIotCoreEventListenerDriver(BaseEventListenerDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) iotdata_client: Any = field(default=Factory(lambda self: self.session.client("iot-data"), takes_self=True)) - def try_publish_event(self, event: BaseEvent) -> None: - self.iotdata_client.publish(topic=self.topic, payload=json.dumps({"event": event.to_dict()})) + def try_publish_event_payload(self, event_payload: dict) -> None: + self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload)) + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload_batch)) diff --git a/griptape/drivers/event_listener/base_event_listener_driver.py b/griptape/drivers/event_listener/base_event_listener_driver.py index 5bfbe6709..8e7f827e9 100644 --- a/griptape/drivers/event_listener/base_event_listener_driver.py +++ b/griptape/drivers/event_listener/base_event_listener_driver.py @@ -1,20 +1,50 @@ from __future__ import annotations + from abc import ABC, abstractmethod from concurrent import futures -from attr import define, field, Factory -from typing import TYPE_CHECKING +from logging import Logger + +from attr import Factory, define, field + +from griptape.events import BaseEvent -if TYPE_CHECKING: - from griptape.events import BaseEvent +logger = Logger(__name__) @define class BaseEventListenerDriver(ABC): futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) + batched: bool = field(default=True, kw_only=True) + batch_size: int = field(default=10, kw_only=True) + + _batch: list[dict] = field(default=Factory(list), kw_only=True) + + @property + def batch(self) -> list[dict]: + return self._batch - def publish_event(self, event: BaseEvent) -> None: - self.futures_executor.submit(self.try_publish_event, event) + def publish_event(self, event: BaseEvent | dict, flush: bool = False) -> None: + self.futures_executor.submit(self._safe_try_publish_event, event, flush) @abstractmethod - def try_publish_event(self, event: BaseEvent) -> None: + def try_publish_event_payload(self, event_payload: dict) -> None: ... + + @abstractmethod + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + ... + + def _safe_try_publish_event(self, event: BaseEvent | dict, flush: bool) -> None: + try: + event_payload = event if isinstance(event, dict) else event.to_dict() + + if self.batched: + self._batch.append(event_payload) + if len(self.batch) >= self.batch_size or flush: + self.try_publish_event_payload_batch(self.batch) + self._batch = [] + return + else: + self.try_publish_event_payload(event_payload) + except Exception as e: + logger.error(e) diff --git a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py index 51b62aeac..2c4149ae7 100644 --- a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py +++ b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py @@ -7,26 +7,43 @@ from attr import define, field, Factory from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver -from griptape.events.base_event import BaseEvent @define class GriptapeCloudEventListenerDriver(BaseEventListenerDriver): - base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) + """Driver for publishing events to Griptape Cloud. + + Attributes: + base_url: The base URL of Griptape Cloud. Defaults to the GT_CLOUD_BASE_URL environment variable. + api_key: The API key to authenticate with Griptape Cloud. + headers: The headers to use when making requests to Griptape Cloud. Defaults to include the Authorization header. + structure_run_id: The ID of the Structure Run to publish events to. Defaults to the GT_CLOUD_STRUCTURE_RUN_ID environment variable. + """ + + base_url: str = field( + default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), kw_only=True + ) api_key: str = field(kw_only=True) headers: dict = field( default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True ) - run_id: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_RUN_ID")), kw_only=True) + structure_run_id: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_STRUCTURE_RUN_ID")), kw_only=True) - @run_id.validator # pyright: ignore - def validate_run_id(self, _, run_id: str): - if run_id is None: + @structure_run_id.validator # pyright: ignore + def validate_run_id(self, _, structure_run_id: str): + if structure_run_id is None: raise ValueError( - "run_id must be set either in the constructor or as an environment variable (GT_CLOUD_RUN_ID)." + "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID)." ) - def try_publish_event(self, event: BaseEvent) -> None: - url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.run_id}/events") + def try_publish_event_payload(self, event_payload: dict) -> None: + url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.structure_run_id}/events") + + response = requests.post(url=url, json=event_payload, headers=self.headers) + response.raise_for_status() + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.structure_run_id}/events") - requests.post(url=url, json=event.to_dict(), headers=self.headers) + response = requests.post(url=url, json=event_payload_batch, headers=self.headers) + response.raise_for_status() diff --git a/griptape/drivers/event_listener/local_event_listener_driver.py b/griptape/drivers/event_listener/local_event_listener_driver.py deleted file mode 100644 index b276b94ef..000000000 --- a/griptape/drivers/event_listener/local_event_listener_driver.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from typing import Callable, Any -from attr import define, field - -from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver -from griptape.events.base_event import BaseEvent - - -@define -class LocalEventListenerDriver(BaseEventListenerDriver): - handler: Callable[[BaseEvent], Any] = field(default=None, kw_only=True) - - def publish_event(self, event: BaseEvent) -> None: - self.try_publish_event(event) - - def try_publish_event(self, event: BaseEvent) -> None: - self.handler(event) diff --git a/griptape/drivers/event_listener/webhook_event_listener_driver.py b/griptape/drivers/event_listener/webhook_event_listener_driver.py index d2f0046d0..242e5428a 100644 --- a/griptape/drivers/event_listener/webhook_event_listener_driver.py +++ b/griptape/drivers/event_listener/webhook_event_listener_driver.py @@ -5,7 +5,6 @@ from attr import define, field from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver -from griptape.events.base_event import BaseEvent @define @@ -13,5 +12,10 @@ class WebhookEventListenerDriver(BaseEventListenerDriver): webhook_url: str = field(kw_only=True) headers: dict = field(default=None, kw_only=True) - def try_publish_event(self, event: BaseEvent) -> None: - requests.post(url=self.webhook_url, json={"event": event.to_dict()}, headers=self.headers) + def try_publish_event_payload(self, event_payload: dict) -> None: + response = requests.post(url=self.webhook_url, json=event_payload, headers=self.headers) + response.raise_for_status() + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + response = requests.post(url=self.webhook_url, json=event_payload_batch, headers=self.headers) + response.raise_for_status() diff --git a/griptape/drivers/structure_run/__init__.py b/griptape/drivers/structure_run/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/drivers/structure_run/base_structure_run_driver.py b/griptape/drivers/structure_run/base_structure_run_driver.py new file mode 100644 index 000000000..8e10fb231 --- /dev/null +++ b/griptape/drivers/structure_run/base_structure_run_driver.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod + +from attrs import define + +from griptape.artifacts import BaseArtifact + + +@define +class BaseStructureRunDriver(ABC): + def run(self, *args: BaseArtifact) -> BaseArtifact: + return self.try_run(*args) + + @abstractmethod + def try_run(self, *args: BaseArtifact) -> BaseArtifact: + ... diff --git a/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py b/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py new file mode 100644 index 000000000..9ed036995 --- /dev/null +++ b/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import time +from typing import Any +from urllib.parse import urljoin + +from attrs import Factory, define, field + +from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, TextArtifact +from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver + + +@define +class GriptapeCloudStructureRunDriver(BaseStructureRunDriver): + base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) + api_key: str = field(kw_only=True) + headers: dict = field( + default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True + ) + structure_id: str = field(kw_only=True) + structure_run_wait_time_interval: int = field(default=2, kw_only=True) + structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True) + async_run: bool = field(default=False, kw_only=True) + + def try_run(self, *args: BaseArtifact) -> BaseArtifact: + from requests import HTTPError, Response, exceptions, post + + url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/runs") + + try: + response: Response = post(url, json={"args": [arg.value for arg in args]}, headers=self.headers) + response.raise_for_status() + response_json = response.json() + + if self.async_run: + return InfoArtifact("Run started successfully") + else: + return self._get_structure_run_result(response_json["structure_run_id"]) + except (exceptions.RequestException, HTTPError) as err: + return ErrorArtifact(str(err)) + + def _get_structure_run_result(self, structure_run_id: str) -> InfoArtifact | TextArtifact | ErrorArtifact: + url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{structure_run_id}") + + result = self._get_structure_run_result_attempt(url) + status = result["status"] + + wait_attempts = 0 + while status in ("QUEUED", "RUNNING") and wait_attempts < self.structure_run_max_wait_time_attempts: + # wait + time.sleep(self.structure_run_wait_time_interval) + wait_attempts += 1 + result = self._get_structure_run_result_attempt(url) + status = result["status"] + + if wait_attempts >= self.structure_run_max_wait_time_attempts: + return ErrorArtifact( + f"Failed to get Run result after {self.structure_run_max_wait_time_attempts} attempts." + ) + + if status != "SUCCEEDED": + return ErrorArtifact(result) + + if "output" in result: + return TextArtifact.from_dict(result["output"]) + else: + return InfoArtifact("No output found in response") + + def _get_structure_run_result_attempt(self, structure_run_url: str) -> Any: + from requests import get, Response + + response: Response = get(structure_run_url, headers=self.headers) + response.raise_for_status() + + return response.json() diff --git a/griptape/drivers/structure_run/local_structure_run_driver.py b/griptape/drivers/structure_run/local_structure_run_driver.py new file mode 100644 index 000000000..255f91445 --- /dev/null +++ b/griptape/drivers/structure_run/local_structure_run_driver.py @@ -0,0 +1,23 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Callable + +from attrs import define, field + +from griptape.artifacts import BaseArtifact, InfoArtifact +from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver + +if TYPE_CHECKING: + from griptape.structures import Structure + + +@define +class LocalStructureRunDriver(BaseStructureRunDriver): + structure_factory_fn: Callable[[], Structure] = field(kw_only=True) + + def try_run(self, *args: BaseArtifact) -> BaseArtifact: + structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args]) + + if structure_factory_fn.output_task.output is not None: + return structure_factory_fn.output_task.output + else: + return InfoArtifact("No output found in response") diff --git a/griptape/drivers/vector/pinecone_vector_store_driver.py b/griptape/drivers/vector/pinecone_vector_store_driver.py index 0ee86cb4e..7b573ed68 100644 --- a/griptape/drivers/vector/pinecone_vector_store_driver.py +++ b/griptape/drivers/vector/pinecone_vector_store_driver.py @@ -17,8 +17,9 @@ class PineconeVectorStoreDriver(BaseVectorStoreDriver): index: pinecone.Index = field(init=False) def __attrs_post_init__(self) -> None: - pinecone = import_optional_dependency("pinecone") - pinecone.init(api_key=self.api_key, environment=self.environment, project_name=self.project_name) + pinecone = import_optional_dependency("pinecone").Pinecone( + api_key=self.api_key, environment=self.environment, project_name=self.project_name + ) self.index = pinecone.Index(self.index_name) @@ -34,7 +35,7 @@ def upsert_vector( params: dict[str, Any] = {"namespace": namespace} | kwargs - self.index.upsert([(vector_id, vector, meta)], **params) + self.index.upsert(vectors=[(vector_id, vector, meta)], **params) return vector_id @@ -57,7 +58,7 @@ def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreD # https://community.pinecone.io/t/is-there-a-way-to-query-all-the-vectors-and-or-metadata-from-a-namespace/797/5 results = self.index.query( - self.embedding_driver.embed_string(""), top_k=10000, include_metadata=True, namespace=namespace + vector=self.embedding_driver.embed_string(""), top_k=10000, include_metadata=True, namespace=namespace ) return [ @@ -86,7 +87,7 @@ def query( "include_metadata": include_metadata, } | kwargs - results = self.index.query(vector, **params) + results = self.index.query(vector=vector, **params) return [ BaseVectorStoreDriver.QueryResult( diff --git a/griptape/engines/query/vector_query_engine.py b/griptape/engines/query/vector_query_engine.py index ac38a05a3..7ca42f59f 100644 --- a/griptape/engines/query/vector_query_engine.py +++ b/griptape/engines/query/vector_query_engine.py @@ -61,7 +61,11 @@ def query( if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens: text_segments.pop() - user_message = self.user_template_generator.render(query=query, text_segments=text_segments) + system_message = self.system_template_generator.render( + rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), + metadata=metadata, + text_segments=text_segments, + ) break diff --git a/griptape/events/base_event.py b/griptape/events/base_event.py index d32defe96..48a48890e 100644 --- a/griptape/events/base_event.py +++ b/griptape/events/base_event.py @@ -1,11 +1,15 @@ from __future__ import annotations + import time +import uuid from abc import ABC -from attr import define, field, Factory + +from attr import Factory, define, field from griptape.mixins import SerializableMixin @define class BaseEvent(SerializableMixin, ABC): + id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) timestamp: float = field(default=Factory(lambda: time.time()), kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index aa7da5dcd..44d7b2d85 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, TYPE_CHECKING -from attrs import define, field +from typing import Optional, TYPE_CHECKING, Callable +from attrs import define, field, Factory from .base_event import BaseEvent if TYPE_CHECKING: @@ -9,12 +9,17 @@ @define class EventListener: + handler: Callable[[BaseEvent], Optional[dict]] = field(default=Factory(lambda: lambda event: event.to_dict())) event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True) driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True) - def publish_event(self, event: BaseEvent) -> None: + def publish_event(self, event: BaseEvent, flush: bool = False) -> None: event_types = self.event_types if event_types is None or type(event) in event_types: + event_payload = self.handler(event) if self.driver is not None: - self.driver.publish_event(event) + if event_payload is not None and isinstance(event_payload, dict): + self.driver.publish_event(event_payload, flush=flush) + else: + self.driver.publish_event(event, flush=flush) diff --git a/griptape/events/finish_structure_run_event.py b/griptape/events/finish_structure_run_event.py index 8591041e2..2a4688fb9 100644 --- a/griptape/events/finish_structure_run_event.py +++ b/griptape/events/finish_structure_run_event.py @@ -9,6 +9,7 @@ @define class FinishStructureRunEvent(BaseEvent): + structure_id: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) output_task_input: Union[ BaseArtifact, tuple[BaseArtifact, ...], tuple[BaseArtifact, Sequence[BaseArtifact]] ] = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/start_structure_run_event.py b/griptape/events/start_structure_run_event.py index a0780fd1f..311d687a9 100644 --- a/griptape/events/start_structure_run_event.py +++ b/griptape/events/start_structure_run_event.py @@ -9,6 +9,7 @@ @define class StartStructureRunEvent(BaseEvent): + structure_id: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) input_task_input: Union[ BaseArtifact, tuple[BaseArtifact, ...], tuple[BaseArtifact, Sequence[BaseArtifact]] ] = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index e807393f4..9cd28ab67 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -249,23 +249,28 @@ def remove_event_listener(self, event_listener: EventListener) -> None: else: raise ValueError("Event Listener not found.") - def publish_event(self, event: BaseEvent) -> None: + def publish_event(self, event: BaseEvent, flush: bool = False) -> None: for event_listener in self.event_listeners: - event_listener.publish_event(event) + event_listener.publish_event(event, flush) def context(self, task: BaseTask) -> dict[str, Any]: return {"args": self.execution_args, "structure": self} def before_run(self) -> None: self.publish_event( - StartStructureRunEvent(input_task_input=self.input_task.input, input_task_output=self.input_task.output) + StartStructureRunEvent( + structure_id=self.id, input_task_input=self.input_task.input, input_task_output=self.input_task.output + ) ) def after_run(self) -> None: self.publish_event( FinishStructureRunEvent( - output_task_input=self.output_task.input, output_task_output=self.output_task.output - ) + structure_id=self.id, + output_task_input=self.output_task.input, + output_task_output=self.output_task.output, + ), + flush=True, ) @abstractmethod diff --git a/griptape/tasks/__init__.py b/griptape/tasks/__init__.py index 110d7dbe2..e51335241 100644 --- a/griptape/tasks/__init__.py +++ b/griptape/tasks/__init__.py @@ -1,5 +1,6 @@ from .base_task import BaseTask from .base_text_input_task import BaseTextInputTask +from .base_multi_text_input_task import BaseMultiTextInputTask from .prompt_task import PromptTask from .actions_subtask import ActionsSubtask from .toolkit_task import ToolkitTask @@ -18,10 +19,12 @@ from .image_query_task import ImageQueryTask from .base_audio_generation_task import BaseAudioGenerationTask from .text_to_speech_task import TextToSpeechTask +from .structure_run_task import StructureRunTask __all__ = [ "BaseTask", "BaseTextInputTask", + "BaseMultiTextInputTask", "PromptTask", "ActionsSubtask", "ToolkitTask", @@ -40,4 +43,5 @@ "ImageQueryTask", "BaseAudioGenerationTask", "TextToSpeechTask", + "StructureRunTask", ] diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py new file mode 100644 index 000000000..eb00af6ca --- /dev/null +++ b/griptape/tasks/base_multi_text_input_task.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from abc import ABC +from typing import Callable + +from attr import define, field, Factory + +from griptape.artifacts import TextArtifact +from griptape.mixins.rule_mixin import RuleMixin +from griptape.tasks import BaseTask +from griptape.utils import J2 + + +@define +class BaseMultiTextInputTask(RuleMixin, BaseTask, ABC): + DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" + + _input: tuple[str, ...] | tuple[TextArtifact, ...] | tuple[Callable[[BaseTask], TextArtifact], ...] = field( + default=Factory(lambda self: (self.DEFAULT_INPUT_TEMPLATE,), takes_self=True), alias="input" + ) + + @property + def input(self) -> tuple[TextArtifact, ...]: + if all(isinstance(elem, TextArtifact) for elem in self._input): + return self._input # pyright: ignore + elif all(isinstance(elem, Callable) for elem in self._input): + return tuple([elem(self) for elem in self._input]) # pyright: ignore + elif isinstance(self._input, tuple): + return tuple( + [ + TextArtifact(J2().render_from_string(input_template, **self.full_context)) # pyright: ignore + for input_template in self._input + ] + ) + else: + return tuple([TextArtifact(J2().render_from_string(self._input, **self.full_context))]) + + @input.setter + def input( + self, value: tuple[str, ...] | tuple[TextArtifact, ...] | tuple[Callable[[BaseTask], TextArtifact], ...] + ) -> None: + self._input = value + + def before_run(self) -> None: + super().before_run() + + joined_input = "\n".join([input.to_text() for input in self.input]) + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {joined_input}") + + def after_run(self) -> None: + super().after_run() + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") diff --git a/griptape/tasks/structure_run_task.py b/griptape/tasks/structure_run_task.py new file mode 100644 index 000000000..b7a8f9ea9 --- /dev/null +++ b/griptape/tasks/structure_run_task.py @@ -0,0 +1,22 @@ +from __future__ import annotations + + +from attr import define, field + +from griptape.artifacts import BaseArtifact +from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver +from griptape.tasks import BaseMultiTextInputTask + + +@define +class StructureRunTask(BaseMultiTextInputTask): + """Task to run a Structure. + + Attributes: + driver: Driver to run the Structure. + """ + + driver: BaseStructureRunDriver = field(kw_only=True) + + def run(self) -> BaseArtifact: + return self.driver.run(*self.input) diff --git a/griptape/tokenizers/openai_tokenizer.py b/griptape/tokenizers/openai_tokenizer.py index 08c334e0c..dda8bfe15 100644 --- a/griptape/tokenizers/openai_tokenizer.py +++ b/griptape/tokenizers/openai_tokenizer.py @@ -10,7 +10,7 @@ class OpenAiTokenizer(BaseTokenizer): DEFAULT_OPENAI_GPT_3_COMPLETION_MODEL = "gpt-3.5-turbo-instruct" DEFAULT_OPENAI_GPT_3_CHAT_MODEL = "gpt-3.5-turbo" - DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4" + DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4o" DEFAULT_ENCODING = "cl100k_base" DEFAULT_MAX_TOKENS = 2049 DEFAULT_MAX_OUTPUT_TOKENS = 4096 @@ -18,6 +18,7 @@ class OpenAiTokenizer(BaseTokenizer): # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = { + "gpt-4o": 128000, "gpt-4-1106": 128000, "gpt-4-32k": 32768, "gpt-4": 8192, @@ -85,6 +86,7 @@ def count_tokens(self, text: str | list[dict], model: Optional[str] = None) -> i "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613", + "gpt-4o-2024-05-13", }: tokens_per_message = 3 tokens_per_name = 1 @@ -96,6 +98,9 @@ def count_tokens(self, text: str | list[dict], model: Optional[str] = None) -> i elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model: logging.info("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") return self.count_tokens(text, model="gpt-3.5-turbo-0613") + elif "gpt-4o" in model: + logging.info("gpt-4o may update over time. Returning num tokens assuming gpt-4o-2024-05-13.") + return self.count_tokens(text, model="gpt-4o-2024-05-13") elif "gpt-4" in model: logging.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") return self.count_tokens(text, model="gpt-4-0613") diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index 5e68ca3fc..0c9b6e01d 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -24,6 +24,7 @@ from .inpainting_image_generation_client.tool import InpaintingImageGenerationClient from .outpainting_image_generation_client.tool import OutpaintingImageGenerationClient from .griptape_cloud_knowledge_base_client.tool import GriptapeCloudKnowledgeBaseClient +from .structure_run_client.tool import StructureRunClient from .image_query_client.tool import ImageQueryClient __all__ = [ @@ -53,5 +54,6 @@ "InpaintingImageGenerationClient", "OutpaintingImageGenerationClient", "GriptapeCloudKnowledgeBaseClient", + "StructureRunClient", "ImageQueryClient", ] diff --git a/griptape/tools/base_griptape_cloud_client.py b/griptape/tools/base_griptape_cloud_client.py new file mode 100644 index 000000000..cafe01cd1 --- /dev/null +++ b/griptape/tools/base_griptape_cloud_client.py @@ -0,0 +1,20 @@ +from __future__ import annotations +from abc import ABC +from attr import Factory, define, field +from griptape.tools import BaseTool + + +@define +class BaseGriptapeCloudClient(BaseTool, ABC): + """ + Attributes: + base_url: Base URL for the Griptape Cloud Knowledge Base API. + api_key: API key for Griptape Cloud. + headers: Headers for the Griptape Cloud Knowledge Base API. + """ + + base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) + api_key: str = field(kw_only=True) + headers: dict = field( + default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True + ) diff --git a/griptape/tools/griptape_cloud_knowledge_base_client/tool.py b/griptape/tools/griptape_cloud_knowledge_base_client/tool.py index da988284f..89a284c98 100644 --- a/griptape/tools/griptape_cloud_knowledge_base_client/tool.py +++ b/griptape/tools/griptape_cloud_knowledge_base_client/tool.py @@ -2,29 +2,21 @@ from typing import Optional from urllib.parse import urljoin from schema import Schema, Literal -from attr import define, field, Factory -from griptape.tools import BaseTool +from attr import define, field +from griptape.tools.base_griptape_cloud_client import BaseGriptapeCloudClient from griptape.utils.decorators import activity from griptape.artifacts import TextArtifact, ErrorArtifact @define -class GriptapeCloudKnowledgeBaseClient(BaseTool): +class GriptapeCloudKnowledgeBaseClient(BaseGriptapeCloudClient): """ Attributes: description: LLM-friendly knowledge base description. - base_url: Base URL for the Griptape Cloud Knowledge Base API. - api_key: API key for Griptape Cloud. - headers: Headers for the Griptape Cloud Knowledge Base API. knowledge_base_id: ID of the Griptape Cloud Knowledge Base. """ description: Optional[str] = field(default=None, kw_only=True) - base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) - api_key: str = field(kw_only=True) - headers: dict = field( - default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True - ) knowledge_base_id: str = field(kw_only=True) @activity( diff --git a/griptape/tools/structure_run_client/__init__.py b/griptape/tools/structure_run_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/structure_run_client/manifest.yml b/griptape/tools/structure_run_client/manifest.yml new file mode 100644 index 000000000..5f53158d8 --- /dev/null +++ b/griptape/tools/structure_run_client/manifest.yml @@ -0,0 +1,5 @@ +version: "v1" +name: Structure Run Client +description: Tool for running a Structure. +contact_email: hello@griptape.ai +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/structure_run_client/tool.py b/griptape/tools/structure_run_client/tool.py new file mode 100644 index 000000000..c62b53e97 --- /dev/null +++ b/griptape/tools/structure_run_client/tool.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from attr import define, field +from schema import Literal, Schema + +from griptape.artifacts import BaseArtifact, TextArtifact +from griptape.drivers import BaseStructureRunDriver +from griptape.tools.base_tool import BaseTool +from griptape.utils.decorators import activity + + +@define +class StructureRunClient(BaseTool): + """ + Attributes: + description: A description of what the Structure does. + driver: Driver to run the Structure. + """ + + description: str = field(kw_only=True) + driver: BaseStructureRunDriver = field(kw_only=True) + + @activity( + config={ + "description": "Can be used to run a Griptape Structure with the following description: {{ self.description }}", + "schema": Schema( + {Literal("args", description="A list of string arguments to submit to the Structure Run"): list} + ), + } + ) + def run_structure(self, params: dict) -> BaseArtifact: + args: list[str] = params["values"]["args"] + + return self.driver.run(*[TextArtifact(arg) for arg in args]) diff --git a/griptape/utils/__init__.py b/griptape/utils/__init__.py index 1aad72db9..64ca9a9f7 100644 --- a/griptape/utils/__init__.py +++ b/griptape/utils/__init__.py @@ -8,8 +8,8 @@ from .futures import execute_futures_dict from .token_counter import TokenCounter from .prompt_stack import PromptStack -from .dict_utils import remove_null_values_in_dict_recursively -from .dict_utils import dict_merge +from .dict_utils import remove_null_values_in_dict_recursively, dict_merge +from .file_utils import load_file, load_files from .hash import str_to_hash from .import_utils import import_optional_dependency from .import_utils import is_dependency_installed @@ -43,4 +43,6 @@ def minify_json(value: str) -> str: "constants", "load_artifact_from_memory", "deprecation_warn", + "load_file", + "load_files", ] diff --git a/griptape/utils/file_utils.py b/griptape/utils/file_utils.py new file mode 100644 index 000000000..402436a2f --- /dev/null +++ b/griptape/utils/file_utils.py @@ -0,0 +1,35 @@ +import griptape.utils as utils +from concurrent import futures +from typing import Optional + + +def load_file(path: str) -> bytes: + """Load a file from the given path and return its content as bytes. + + Args: + path (str): The path to the file to load. + + Returns: + The content of the file. + """ + with open(path, "rb") as f: + return f.read() + + +def load_files(paths: list[str], futures_executor: Optional[futures.ThreadPoolExecutor] = None) -> dict[str, bytes]: + """Load multiple files concurrently and return a dictionary of their content. + + Args: + paths: The paths to the files to load. + futures_executor: The executor to use for concurrent loading. If None, a new ThreadPoolExecutor will be created. + + Returns: + A dictionary where the keys are a hash of the path and the values are the content of the files. + """ + + if futures_executor is None: + futures_executor = futures.ThreadPoolExecutor() + + return utils.execute_futures_dict( + {utils.str_to_hash(str(path)): futures_executor.submit(load_file, path) for path in paths} + ) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 80d3ea5a1..a0251fa57 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -53,14 +53,11 @@ def run(self, *args) -> Iterator[TextArtifact]: t.join() def _run_structure(self, *args): - from griptape.drivers import LocalEventListenerDriver - def event_handler(event: BaseEvent): self._event_queue.put(event) stream_event_listener = EventListener( - driver=LocalEventListenerDriver(handler=event_handler), - event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], + handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent] ) self.structure.add_event_listener(stream_event_listener) diff --git a/mkdocs.yml b/mkdocs.yml index bb90117a3..f5f0dc954 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -74,6 +74,9 @@ nav: - Home: - Overview: "index.md" - Contributing: "contributing.md" + - Cloud: + - Cloud API: + - API Reference: "griptape-cloud/api/api-reference.md" - Framework: - Overview: "griptape-framework/index.md" - Structures: @@ -103,6 +106,7 @@ nav: - Image Query Drivers: "griptape-framework/drivers/image-query-drivers.md" - Web Scraper Drivers: "griptape-framework/drivers/web-scraper-drivers.md" - Conversation Memory Drivers: "griptape-framework/drivers/conversation-memory-drivers.md" + - Event Listener Drivers: "griptape-framework/drivers/event-listener-drivers.md" - Data: - Overview: "griptape-framework/data/index.md" - Artifacts: "griptape-framework/data/artifacts.md" @@ -125,6 +129,7 @@ nav: - GoogleGmailClient: "griptape-tools/official-tools/google-gmail-client.md" - GoogleDriveClient: "griptape-tools/official-tools/google-drive-client.md" - GoogleDocsClient: "griptape-tools/official-tools/google-docs-client.md" + - StructureRunClient: "griptape-tools/official-tools/structure-run-client.md" - OpenWeatherClient: "griptape-tools/official-tools/openweather-client.md" - RestApiClient: "griptape-tools/official-tools/rest-api-client.md" - SqlClient: "griptape-tools/official-tools/sql-client.md" @@ -145,6 +150,7 @@ nav: - Talk to Redshift: "examples/talk-to-redshift.md" - Talk to a Webpage: "examples/talk-to-a-webpage.md" - Talk to a PDF: "examples/talk-to-a-pdf.md" + - Multi Agent Workflows: "examples/multi-agent-workflow.md" - Shared Memory Between Agents: "examples/multiple-agent-shared-memory.md" - Chat Sessions with Amazon DynamoDB: "examples/amazon-dynamodb-sessions.md" - Data: diff --git a/poetry.lock b/poetry.lock index ac794bc1e..37908cc4f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2168,24 +2168,6 @@ files = [ [package.extras] data = ["language-data (>=1.1,<2.0)"] -[[package]] -name = "loguru" -version = "0.7.2" -description = "Python logging made (stupidly) simple" -optional = true -python-versions = ">=3.5" -files = [ - {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, - {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, -] - -[package.dependencies] -colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} -win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} - -[package.extras] -dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] - [[package]] name = "lxml" version = "4.9.4" @@ -3287,28 +3269,26 @@ xmp = ["defusedxml"] [[package]] name = "pinecone-client" -version = "2.2.4" +version = "3.2.2" description = "Pinecone client and SDK" optional = true -python-versions = ">=3.8" +python-versions = "<4.0,>=3.8" files = [ - {file = "pinecone-client-2.2.4.tar.gz", hash = "sha256:2c1cc1d6648b2be66e944db2ffa59166a37b9164d1135ad525d9cd8b1e298168"}, - {file = "pinecone_client-2.2.4-py3-none-any.whl", hash = "sha256:5bf496c01c2f82f4e5c2dc977cc5062ecd7168b8ed90743b09afcc8c7eb242ec"}, + {file = "pinecone_client-3.2.2-py3-none-any.whl", hash = "sha256:7e492fdda23c73726bc0cb94c689bb950d06fb94e82b701a0c610c2e830db327"}, + {file = "pinecone_client-3.2.2.tar.gz", hash = "sha256:887a12405f90ac11c396490f605fc479f31cf282361034d1ae0fccc02ac75bee"}, ] [package.dependencies] -dnspython = ">=2.0.0" -loguru = ">=0.5.0" -numpy = ">=1.22.0" -python-dateutil = ">=2.5.3" -pyyaml = ">=5.4" -requests = ">=2.19.0" +certifi = ">=2019.11.17" tqdm = ">=4.64.1" typing-extensions = ">=3.7.4" -urllib3 = ">=1.21.1" +urllib3 = [ + {version = ">=1.26.0", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, + {version = ">=1.26.5", markers = "python_version >= \"3.12\" and python_version < \"4.0\""}, +] [package.extras] -grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "lz4 (>=3.1.3)", "protobuf (>=3.20.0,<3.21.0)"] +grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "grpcio (>=1.59.0)", "lz4 (>=3.1.3)", "protobuf (>=3.20.0,<3.21.0)"] [[package]] name = "pkginfo" @@ -3487,7 +3467,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, @@ -3496,8 +3475,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -5573,20 +5550,6 @@ MarkupSafe = ">=2.1.1" [package.extras] watchdog = ["watchdog (>=2.3)"] -[[package]] -name = "win32-setctime" -version = "1.1.0" -description = "A small Python utility to set file creation time on Windows" -optional = true -python-versions = ">=3.5" -files = [ - {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, - {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, -] - -[package.extras] -dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] - [[package]] name = "xmltodict" version = "0.13.0" @@ -5752,4 +5715,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "106d17537887f67f16d71cb2d5e0a1c3c0c76ab91102cc55c79fccbea757b6b4" +content-hash = "7aa1485db323176c7b372efd3483d060c469d18fdf0c6ed172bb3a82d4ab238b" diff --git a/pyproject.toml b/pyproject.toml index caea403f4..fbe331b05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "griptape" -version = "0.24.2" +version = "0.25.0" description = "Modular Python framework for LLM workflows, tools, memory, and data." authors = ["Griptape "] license = "Apache 2.0" @@ -38,7 +38,7 @@ huggingface-hub = { version = ">=0.13", optional = true } boto3 = { version = "^1.28.2", optional = true } sqlalchemy-redshift = { version = "*", optional = true } snowflake-sqlalchemy = { version = "^1.4.7", optional = true } -pinecone-client = { version = "^2", optional = true } +pinecone-client = { version = "^3", optional = true } pymongo = { version = "*", optional = true } marqo = { version = ">=1.1.0", optional = true } redis = { version = "^4.6.0", optional = true } diff --git a/tests/mocks/mock_event.py b/tests/mocks/mock_event.py index 651cf3ece..2b9d9ade3 100644 --- a/tests/mocks/mock_event.py +++ b/tests/mocks/mock_event.py @@ -3,4 +3,4 @@ class MockEvent(BaseEvent): def to_dict(self) -> dict: - return {"timestamp": self.timestamp} + return {"timestamp": self.timestamp, "id": self.id} diff --git a/tests/mocks/mock_event_listener_driver.py b/tests/mocks/mock_event_listener_driver.py new file mode 100644 index 000000000..dd54eeb73 --- /dev/null +++ b/tests/mocks/mock_event_listener_driver.py @@ -0,0 +1,12 @@ +from attr import define + +from griptape.drivers import BaseEventListenerDriver + + +@define +class MockEventListenerDriver(BaseEventListenerDriver): + def try_publish_event_payload(self, event_payload: dict) -> None: + pass + + def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: + pass diff --git a/tests/mocks/mock_multi_text_input_task.py b/tests/mocks/mock_multi_text_input_task.py new file mode 100644 index 000000000..1da645ee4 --- /dev/null +++ b/tests/mocks/mock_multi_text_input_task.py @@ -0,0 +1,9 @@ +from attr import define +from griptape.artifacts import TextArtifact +from griptape.tasks import BaseMultiTextInputTask + + +@define +class MockMultiTextInputTask(BaseMultiTextInputTask): + def run(self) -> TextArtifact: + return TextArtifact(self.input[0].to_text()) diff --git a/tests/resources/foobar-many.txt b/tests/resources/foobar-many.txt new file mode 100644 index 000000000..00fbc1398 --- /dev/null +++ b/tests/resources/foobar-many.txt @@ -0,0 +1,26 @@ +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + +foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar foobar + diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index 8c3288cfe..2321665ae 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -19,7 +19,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, @@ -73,7 +73,7 @@ def test_to_dict(self, config): "prompt_driver": { "base_url": None, "type": "OpenAiChatPromptDriver", - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, @@ -99,7 +99,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, @@ -114,7 +114,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, @@ -130,7 +130,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, diff --git a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py index 4a19011be..706831d67 100644 --- a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py @@ -27,5 +27,8 @@ def driver(self): def test_init(self, driver): assert driver - def test_try_publish_event(self, driver): - driver.try_publish_event(event=MockEvent()) + def test_try_publish_event_payload(self, driver): + driver.try_publish_event_payload(MockEvent().to_dict()) + + def test_try_publish_event_payload_batch(self, driver): + driver.try_publish_event_payload_batch([MockEvent().to_dict() for _ in range(3)]) diff --git a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py index c8b6a77ed..9a5fe9ec0 100644 --- a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py @@ -21,5 +21,8 @@ def driver(self): def test_init(self, driver): assert driver - def test_try_publish_event(self, driver): - driver.try_publish_event(event=MockEvent()) + def test_try_publish_event_payload(self, driver): + driver.try_publish_event_payload(MockEvent().to_dict()) + + def test_try_publish_event_payload_batch(self, driver): + driver.try_publish_event_payload_batch([MockEvent().to_dict() for _ in range(3)]) diff --git a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py new file mode 100644 index 000000000..6d33dd2a0 --- /dev/null +++ b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py @@ -0,0 +1,25 @@ +from tests.mocks.mock_event import MockEvent +from tests.mocks.mock_event_listener_driver import MockEventListenerDriver + + +class TestBaseEventListenerDriver: + def test__safe_try_publish_event(self): + driver = MockEventListenerDriver(batched=False) + + for _ in range(4): + driver._safe_try_publish_event(MockEvent().to_dict(), flush=False) + assert len(driver.batch) == 0 + + def test__safe_try_publish_event_batch(self): + driver = MockEventListenerDriver(batched=True) + + for _ in range(0, 3): + driver._safe_try_publish_event(MockEvent().to_dict(), flush=False) + assert len(driver.batch) == 3 + + def test__safe_try_publish_event_batch_flush(self): + driver = MockEventListenerDriver(batched=True) + + for _ in range(0, 3): + driver._safe_try_publish_event(MockEvent().to_dict(), flush=True) + assert len(driver.batch) == 0 diff --git a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py index c459fc8a2..d27f09ec8 100644 --- a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py @@ -1,8 +1,11 @@ +import os from unittest.mock import Mock -from pytest import fixture + import pytest -from tests.mocks.mock_event import MockEvent +from pytest import fixture + from griptape.drivers.event_listener.griptape_cloud_event_listener_driver import GriptapeCloudEventListenerDriver +from tests.mocks.mock_event import MockEvent class TestGriptapeCloudEventListenerDriver: @@ -17,21 +20,36 @@ def mock_post(self, mocker): @fixture() def driver(self): - return GriptapeCloudEventListenerDriver(api_key="foo bar", run_id="baz") + os.environ["GT_CLOUD_BASE_URL"] = "https://cloud123.griptape.ai" + + return GriptapeCloudEventListenerDriver(api_key="foo bar", structure_run_id="bar baz") def test_init(self, driver): assert driver + assert driver.api_key == "foo bar" + assert driver.structure_run_id == "bar baz" - def test_try_publish_event(self, mock_post, driver): + def test_try_publish_event_payload(self, mock_post, driver): event = MockEvent() - driver.try_publish_event(event=event) + driver.try_publish_event_payload(event.to_dict()) mock_post.assert_called_once_with( - url=f"https://cloud.griptape.ai/api/structure-runs/{driver.run_id}/events", + url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events", json=event.to_dict(), headers={"Authorization": "Bearer foo bar"}, ) - def test_no_run_id(self): + def try_publish_event_payload_batch(self, mock_post, driver): + for _ in range(3): + event = MockEvent() + driver.try_publish_event_payload(event.to_dict()) + + mock_post.assert_called_with( + url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events", + json=event.to_dict(), + headers={"Authorization": "Bearer foo bar"}, + ) + + def test_no_structure_run_id(self): with pytest.raises(ValueError): GriptapeCloudEventListenerDriver(api_key="foo bar") diff --git a/tests/unit/drivers/event_listener/test_local_event_listener_driver.py b/tests/unit/drivers/event_listener/test_local_event_listener_driver.py deleted file mode 100644 index a92276fa8..000000000 --- a/tests/unit/drivers/event_listener/test_local_event_listener_driver.py +++ /dev/null @@ -1,14 +0,0 @@ -from moto import mock_iotdata -from unittest.mock import Mock -from tests.mocks.mock_event import MockEvent -from griptape.drivers.event_listener.local_event_listener_driver import LocalEventListenerDriver - - -@mock_iotdata -class TestLocalEventListenerDriver: - def test_try_publish_event(self): - mock = Mock() - event = MockEvent() - driver = LocalEventListenerDriver(handler=mock) - driver.try_publish_event(event=event) - mock.assert_called_once_with(event) diff --git a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py index 207fd48cf..50021cbe3 100644 --- a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py @@ -15,11 +15,22 @@ def mock_post(self, mocker): def test_init(self): assert WebhookEventListenerDriver(webhook_url="") - def test_try_publish_event(self, mock_post): + def test_try_publish_event_payload(self, mock_post): driver = WebhookEventListenerDriver(webhook_url="foo bar", headers={"Authorization": "Bearer foo bar"}) event = MockEvent() - driver.try_publish_event(event=event) + driver.try_publish_event_payload(event.to_dict()) mock_post.assert_called_once_with( - url="foo bar", json={"event": event.to_dict()}, headers={"Authorization": "Bearer foo bar"} + url="foo bar", json=event.to_dict(), headers={"Authorization": "Bearer foo bar"} ) + + def test_try_publish_event_payload_batch(self, mock_post): + driver = WebhookEventListenerDriver(webhook_url="foo bar", headers={"Authorization": "Bearer foo bar"}) + + for _ in range(3): + event = MockEvent() + driver.try_publish_event_payload(event.to_dict()) + + mock_post.assert_called_with( + url="foo bar", json=event.to_dict(), headers={"Authorization": "Bearer foo bar"} + ) diff --git a/tests/unit/drivers/structure_run/__init__.py b/tests/unit/drivers/structure_run/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py new file mode 100644 index 000000000..8318553b1 --- /dev/null +++ b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py @@ -0,0 +1,35 @@ +import pytest +from griptape.artifacts import TextArtifact, InfoArtifact + + +class TestGriptapeCloudStructureRunDriver: + @pytest.fixture + def driver(self, mocker): + from griptape.drivers import GriptapeCloudStructureRunDriver + + mock_response = mocker.Mock() + mock_response.json.return_value = {"structure_run_id": 1} + mocker.patch("requests.post", return_value=mock_response) + + mock_response = mocker.Mock() + mock_response.json.return_value = { + "description": "fizz buzz", + "output": TextArtifact("foo bar").to_dict(), + "status": "SUCCEEDED", + } + mocker.patch("requests.get", return_value=mock_response) + + return GriptapeCloudStructureRunDriver( + base_url="https://cloud-foo.griptape.ai", api_key="foo bar", structure_id="1" + ) + + def test_run(self, driver): + result = driver.run(TextArtifact("foo bar")) + assert isinstance(result, TextArtifact) + assert result.value == "foo bar" + + def test_async_run(self, driver): + driver.async_run = True + result = driver.run(TextArtifact("foo bar")) + assert isinstance(result, InfoArtifact) + assert result.value == "Run started successfully" diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py new file mode 100644 index 000000000..04da4f2cf --- /dev/null +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -0,0 +1,24 @@ +import pytest +from griptape.tasks import StructureRunTask +from griptape.structures import Agent +from tests.mocks.mock_prompt_driver import MockPromptDriver +from griptape.drivers import LocalStructureRunDriver +from griptape.structures import Pipeline + + +class TestLocalStructureRunDriver: + @pytest.fixture + def driver(self): + agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output")) + driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) + + return driver + + def test_run(self, driver): + pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output")) + + task = StructureRunTask(driver=driver) + + pipeline.add_task(task) + + assert task.run().to_text() == "agent mock output" diff --git a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py index 993821de9..defbe8e3b 100644 --- a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py +++ b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py @@ -18,10 +18,10 @@ def mock_pinecone(self, mocker): "namespace": "foobar", } - mocker.patch("pinecone.init", return_value=None) - mocker.patch("pinecone.Index.upsert", return_value=None) - mocker.patch("pinecone.Index.query", return_value=fake_query_response) - mocker.patch("pinecone.create_index", return_value=None) + mock_client = mocker.patch("pinecone.Pinecone") + mock_client().Index().upsert.return_value = None + mock_client().Index().query.return_value = fake_query_response + mock_client().create_index.return_value = None @pytest.fixture def driver(self): diff --git a/tests/unit/events/test_base_event.py b/tests/unit/events/test_base_event.py index 2def78015..7656b6b0d 100644 --- a/tests/unit/events/test_base_event.py +++ b/tests/unit/events/test_base_event.py @@ -177,6 +177,7 @@ def test_start_structure_run_event_from_dict(self): dict_value = { "type": "StartStructureRunEvent", "timestamp": 123.0, + "structure_id": "foo", "input_task_input": {"type": "TextArtifact", "value": "foo"}, "input_task_output": {"type": "TextArtifact", "value": "bar"}, } @@ -193,6 +194,7 @@ def test_finish_structure_run_event_from_dict(self): dict_value = { "type": "FinishStructureRunEvent", "timestamp": 123.0, + "structure_id": "foo", "output_task_input": {"type": "TextArtifact", "value": "foo"}, "output_task_output": {"type": "TextArtifact", "value": "bar"}, } diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 39b59ea94..2f32837e0 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -1,6 +1,6 @@ from unittest.mock import Mock import pytest -from griptape.drivers.event_listener.local_event_listener_driver import LocalEventListenerDriver +from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline from griptape.tasks import ToolkitTask, ActionsSubtask from griptape.events import ( @@ -17,6 +17,7 @@ ) from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool +from tests.mocks.mock_event import MockEvent class TestEventListener: @@ -34,10 +35,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - pipeline.event_listeners = [ - EventListener(driver=LocalEventListenerDriver(handler=event_handler_1)), - EventListener(driver=LocalEventListenerDriver(handler=event_handler_2)), - ] + pipeline.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() pipeline.tasks[0].subtasks[0].after_run() @@ -58,37 +56,15 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler = Mock() pipeline.event_listeners = [ - EventListener( - driver=LocalEventListenerDriver(handler=start_prompt_event_handler), event_types=[StartPromptEvent] - ), - EventListener( - driver=LocalEventListenerDriver(handler=finish_prompt_event_handler), event_types=[FinishPromptEvent] - ), - EventListener( - driver=LocalEventListenerDriver(handler=start_task_event_handler), event_types=[StartTaskEvent] - ), - EventListener( - driver=LocalEventListenerDriver(handler=finish_task_event_handler), event_types=[FinishTaskEvent] - ), - EventListener( - driver=LocalEventListenerDriver(handler=start_subtask_event_handler), - event_types=[StartActionsSubtaskEvent], - ), - EventListener( - driver=LocalEventListenerDriver(handler=finish_subtask_event_handler), - event_types=[FinishActionsSubtaskEvent], - ), - EventListener( - driver=LocalEventListenerDriver(handler=start_structure_run_event_handler), - event_types=[StartStructureRunEvent], - ), - EventListener( - driver=LocalEventListenerDriver(handler=finish_structure_run_event_handler), - event_types=[FinishStructureRunEvent], - ), - EventListener( - driver=LocalEventListenerDriver(handler=completion_chunk_handler), event_types=[CompletionChunkEvent] - ), + EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), + EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), + EventListener(start_task_event_handler, event_types=[StartTaskEvent]), + EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), + EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), + EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), + EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), + EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), + EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), ] # can't mock subtask events, so must manually call @@ -110,23 +86,45 @@ def test_add_remove_event_listener(self, pipeline): pipeline.event_listeners = [] mock1 = Mock() mock2 = Mock() - event_listener_1 = pipeline.add_event_listener( - EventListener(driver=LocalEventListenerDriver(handler=mock1), event_types=[StartPromptEvent]) - ) + # duplicate event listeners will only get added once + event_listener_1 = pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_2 = pipeline.add_event_listener( - EventListener(driver=LocalEventListenerDriver(handler=mock1), event_types=[FinishPromptEvent]) - ) - event_listener_3 = pipeline.add_event_listener( - EventListener(driver=LocalEventListenerDriver(handler=mock2), event_types=[StartPromptEvent]) - ) + event_listener_3 = pipeline.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = pipeline.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_4 = pipeline.add_event_listener(EventListener(driver=LocalEventListenerDriver(handler=mock2))) + event_listener_5 = pipeline.add_event_listener(EventListener(mock2)) assert len(pipeline.event_listeners) == 4 pipeline.remove_event_listener(event_listener_1) - pipeline.remove_event_listener(event_listener_2) pipeline.remove_event_listener(event_listener_3) pipeline.remove_event_listener(event_listener_4) + pipeline.remove_event_listener(event_listener_5) assert len(pipeline.event_listeners) == 0 + + def test_publish_event(self): + mock_event_listener_driver = Mock() + mock_event_listener_driver.try_publish_event_payload.return_value = None + + def event_handler(_: BaseEvent): + return None + + mock_event = MockEvent() + event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent]) + event_listener.publish_event(mock_event) + + mock_event_listener_driver.publish_event.assert_called_once_with(mock_event, flush=False) + + def test_publish_transformed_event(self): + mock_event_listener_driver = Mock() + mock_event_listener_driver.publish_event.return_value = None + + def event_handler(event: BaseEvent): + return {"event": event.to_dict()} + + mock_event = MockEvent() + event_listener = EventListener(event_handler, driver=mock_event_listener_driver, event_types=[MockEvent]) + event_listener.publish_event(mock_event) + + mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}, flush=False) diff --git a/tests/unit/events/test_finish_structure_run_event.py b/tests/unit/events/test_finish_structure_run_event.py index d369ab5e5..68ad1ea01 100644 --- a/tests/unit/events/test_finish_structure_run_event.py +++ b/tests/unit/events/test_finish_structure_run_event.py @@ -6,10 +6,13 @@ class TestFinishStructureRunEvent: @pytest.fixture def finish_structure_run_event(self): - return FinishStructureRunEvent(output_task_input=TextArtifact("foo"), output_task_output=TextArtifact("bar")) + return FinishStructureRunEvent( + structure_id="fizz", output_task_input=TextArtifact("foo"), output_task_output=TextArtifact("bar") + ) def test_to_dict(self, finish_structure_run_event): assert finish_structure_run_event.to_dict() is not None + assert finish_structure_run_event.to_dict()["structure_id"] == "fizz" assert finish_structure_run_event.to_dict()["output_task_input"]["value"] == "foo" assert finish_structure_run_event.to_dict()["output_task_output"]["value"] == "bar" diff --git a/tests/unit/events/test_start_structure_run_event.py b/tests/unit/events/test_start_structure_run_event.py index 945b38e64..c2f1b923d 100644 --- a/tests/unit/events/test_start_structure_run_event.py +++ b/tests/unit/events/test_start_structure_run_event.py @@ -6,9 +6,12 @@ class TestStartStructureRunEvent: @pytest.fixture def start_structure_run_event(self): - return StartStructureRunEvent(input_task_input=TextArtifact("foo"), input_task_output=TextArtifact("bar")) + return StartStructureRunEvent( + structure_id="fizz", input_task_input=TextArtifact("foo"), input_task_output=TextArtifact("bar") + ) def test_to_dict(self, start_structure_run_event): assert start_structure_run_event.to_dict() is not None + assert start_structure_run_event.to_dict()["structure_id"] == "fizz" assert start_structure_run_event.to_dict()["input_task_input"]["value"] == "foo" assert start_structure_run_event.to_dict()["input_task_output"]["value"] == "bar" diff --git a/tests/unit/tasks/test_base_multi_text_input_task.py b/tests/unit/tasks/test_base_multi_text_input_task.py new file mode 100644 index 000000000..542162757 --- /dev/null +++ b/tests/unit/tasks/test_base_multi_text_input_task.py @@ -0,0 +1,58 @@ +from tests.mocks.mock_prompt_driver import MockPromptDriver +from griptape.structures import Pipeline +from griptape.artifacts import TextArtifact +from griptape.rules import Ruleset, Rule +from tests.mocks.mock_multi_text_input_task import MockMultiTextInputTask + + +class TestBaseMultiTextInputTask: + def test_string_input(self): + assert MockMultiTextInputTask(("foobar", "bazbar")).input[0].value == "foobar" + assert MockMultiTextInputTask(("foobar", "bazbar")).input[1].value == "bazbar" + + task = MockMultiTextInputTask() + task.input = ("foobar", "bazbar") + assert task.input[0].value == "foobar" + assert task.input[1].value == "bazbar" + + def test_artifact_input(self): + assert MockMultiTextInputTask((TextArtifact("foobar"), TextArtifact("bazbar"))).input[0].value == "foobar" + assert MockMultiTextInputTask((TextArtifact("foobar"), TextArtifact("bazbar"))).input[1].value == "bazbar" + + task = MockMultiTextInputTask() + task.input = (TextArtifact("foobar"), TextArtifact("bazbar")) + assert task.input[0].value == "foobar" + assert task.input[1].value == "bazbar" + + def test_callable_input(self): + assert ( + MockMultiTextInputTask((lambda _: TextArtifact("foobar"), lambda _: TextArtifact("bazbar"))).input[0].value + == "foobar" + ) + assert ( + MockMultiTextInputTask((lambda _: TextArtifact("foobar"), lambda _: TextArtifact("bazbar"))).input[1].value + == "bazbar" + ) + + task = MockMultiTextInputTask() + task.input = (lambda _: TextArtifact("foobar"), lambda _: TextArtifact("bazbar")) + assert task.input[0].value == "foobar" + assert task.input[1].value == "bazbar" + + def test_full_context(self): + parent = MockMultiTextInputTask(("parent1", "parent2")) + subtask = MockMultiTextInputTask(("test1", "test2"), context={"foo": "bar"}) + child = MockMultiTextInputTask(("child2", "child2")) + pipeline = Pipeline(prompt_driver=MockPromptDriver()) + + pipeline.add_tasks(parent, subtask, child) + + pipeline.run() + + context = subtask.full_context + + assert context["foo"] == "bar" + assert context["parent_output"] == parent.output.to_text() + assert context["structure"] == pipeline + assert context["parent"] == parent + assert context["child"] == child diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py new file mode 100644 index 000000000..d89e98c91 --- /dev/null +++ b/tests/unit/tasks/test_structure_run_task.py @@ -0,0 +1,18 @@ +from griptape.tasks import StructureRunTask +from griptape.structures import Agent +from tests.mocks.mock_prompt_driver import MockPromptDriver +from griptape.drivers import LocalStructureRunDriver +from griptape.structures import Pipeline + + +class TestStructureRunTask: + def test_run(self): + agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output")) + pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output")) + driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) + + task = StructureRunTask(driver=driver) + + pipeline.add_task(task) + + assert task.run().to_text() == "agent mock output" diff --git a/tests/unit/tools/test_structure_run_client.py b/tests/unit/tools/test_structure_run_client.py new file mode 100644 index 000000000..b57bfb28f --- /dev/null +++ b/tests/unit/tools/test_structure_run_client.py @@ -0,0 +1,20 @@ +import pytest +from griptape.drivers.structure_run.local_structure_run_driver import LocalStructureRunDriver +from griptape.tools import StructureRunClient +from griptape.structures import Agent +from tests.mocks.mock_prompt_driver import MockPromptDriver + + +class TestStructureRunClient: + @pytest.fixture + def client(self): + driver = MockPromptDriver() + agent = Agent(prompt_driver=driver) + + return StructureRunClient( + description="foo bar", driver=LocalStructureRunDriver(structure_factory_fn=lambda: agent) + ) + + def test_run_structure(self, client): + assert client.run_structure({"values": {"args": "foo bar"}}).value == "mock output" + assert client.description == "foo bar" diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py new file mode 100644 index 000000000..dbcf1044b --- /dev/null +++ b/tests/unit/utils/test_file_utils.py @@ -0,0 +1,53 @@ +import os +from griptape.loaders import TextLoader +from griptape import utils +from concurrent import futures +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver + +MAX_TOKENS = 50 + + +class TestFileUtils: + def test_load_file(self): + dirname = os.path.dirname(__file__) + file = utils.load_file(os.path.join(dirname, "../../resources/foobar-many.txt")) + + assert file.decode("utf-8").startswith("foobar foobar foobar") + assert len(file.decode("utf-8")) == 4563 + + def test_load_files(self): + dirname = os.path.dirname(__file__) + sources = ["resources/foobar-many.txt", "resources/foobar-many.txt", "resources/small.png"] + sources = [os.path.join(dirname, "../../", source) for source in sources] + files = utils.load_files(sources, futures_executor=futures.ThreadPoolExecutor(max_workers=1)) + assert len(files) == 2 + + test_file = files[utils.str_to_hash(sources[0])] + assert len(test_file) == 4563 + assert test_file.decode("utf-8").startswith("foobar foobar foobar") + + small_file = files[utils.str_to_hash(sources[2])] + assert len(small_file) == 97 + assert small_file[:8] == b"\x89PNG\r\n\x1a\n" + + def test_load_file_with_loader(self): + dirname = os.path.dirname(__file__) + file = utils.load_file(os.path.join(dirname, "../../", "resources/foobar-many.txt")) + artifacts = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()).load(file) + + assert len(artifacts) == 39 + assert isinstance(artifacts, list) + assert artifacts[0].value.startswith("foobar foobar foobar") + + def test_load_files_with_loader(self): + dirname = os.path.dirname(__file__) + sources = ["resources/foobar-many.txt"] + sources = [os.path.join(dirname, "../../", source) for source in sources] + files = utils.load_files(sources) + loader = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + collection = loader.load_collection(list(files.values())) + + test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash(sources[0])])] + assert len(test_file_artifacts) == 39 + assert isinstance(test_file_artifacts, list) + assert test_file_artifacts[0].value.startswith("foobar foobar foobar")