From 943ce723b4227778c732e5e0dbf6d405f092a3fc Mon Sep 17 00:00:00 2001 From: Matt Vallillo Date: Fri, 31 May 2024 09:09:07 -0700 Subject: [PATCH 1/4] Use environment variables for structure runs (#808) --- CHANGELOG.md | 1 + .../base_structure_run_driver.py | 10 +++++++- .../griptape_cloud_structure_run_driver.py | 4 +++- .../local_structure_run_driver.py | 10 +++++++- tests/mocks/mock_prompt_driver.py | 7 +++--- ...est_griptape_cloud_structure_run_driver.py | 2 +- .../test_local_structure_run_driver.py | 23 ++++++++++++------- 7 files changed, 42 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb2e5636a..8dee34445 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `AzureOpenAiStructureConfig` for providing Structures with all Azure OpenAI Driver configuration. - `AzureOpenAiVisionImageQueryDriver` to support queries on images using Azure's OpenAI Vision models. +- Parameter `env` to `BaseStructureRunDriver` to set environment variables for a Structure Run. ### Changed - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. diff --git a/griptape/drivers/structure_run/base_structure_run_driver.py b/griptape/drivers/structure_run/base_structure_run_driver.py index 5c188a18d..4ff9b6eb2 100644 --- a/griptape/drivers/structure_run/base_structure_run_driver.py +++ b/griptape/drivers/structure_run/base_structure_run_driver.py @@ -1,12 +1,20 @@ from abc import ABC, abstractmethod -from attrs import define +from attrs import define, Factory, field from griptape.artifacts import BaseArtifact @define class BaseStructureRunDriver(ABC): + """Base class for Structure Run Drivers. + + Attributes: + env: Environment variables to set before running the Structure. + """ + + env: dict[str, str] = field(default=Factory(dict), kw_only=True) + def run(self, *args: BaseArtifact) -> BaseArtifact: return self.try_run(*args) diff --git a/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py b/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py index 9ed036995..40e6ff874 100644 --- a/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py +++ b/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py @@ -28,7 +28,9 @@ def try_run(self, *args: BaseArtifact) -> BaseArtifact: url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/runs") try: - response: Response = post(url, json={"args": [arg.value for arg in args]}, headers=self.headers) + response: Response = post( + url, json={"args": [arg.value for arg in args], "env": self.env}, headers=self.headers + ) response.raise_for_status() response_json = response.json() diff --git a/griptape/drivers/structure_run/local_structure_run_driver.py b/griptape/drivers/structure_run/local_structure_run_driver.py index 255f91445..4f140c1fe 100644 --- a/griptape/drivers/structure_run/local_structure_run_driver.py +++ b/griptape/drivers/structure_run/local_structure_run_driver.py @@ -6,6 +6,8 @@ from griptape.artifacts import BaseArtifact, InfoArtifact from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver +import os + if TYPE_CHECKING: from griptape.structures import Structure @@ -15,7 +17,13 @@ class LocalStructureRunDriver(BaseStructureRunDriver): structure_factory_fn: Callable[[], Structure] = field(kw_only=True) def try_run(self, *args: BaseArtifact) -> BaseArtifact: - structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args]) + old_env = os.environ.copy() + try: + os.environ.update(self.env) + structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args]) + finally: + os.environ.clear() + os.environ.update(old_env) if structure_factory_fn.output_task.output is not None: return structure_factory_fn.output_task.output diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index ca1f67f5f..e2018c6f6 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -1,5 +1,6 @@ from collections.abc import Iterator -from attr import define, field +from typing import Callable +from attr import Factory, define, field from griptape.utils import PromptStack from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer @@ -11,10 +12,10 @@ class MockPromptDriver(BasePromptDriver): model: str = "test-model" tokenizer: BaseTokenizer = MockTokenizer(model="test-model", max_input_tokens=4096, max_output_tokens=4096) - mock_output: str = field(default="mock output", kw_only=True) + mock_output: str | Callable[[], str] = field(default="mock output", kw_only=True) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - return TextArtifact(value=self.mock_output) + return TextArtifact(value=self.mock_output() if isinstance(self.mock_output, Callable) else self.mock_output) def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: yield TextArtifact(value=self.mock_output) diff --git a/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py index 8318553b1..b056241ec 100644 --- a/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py @@ -20,7 +20,7 @@ def driver(self, mocker): mocker.patch("requests.get", return_value=mock_response) return GriptapeCloudStructureRunDriver( - base_url="https://cloud-foo.griptape.ai", api_key="foo bar", structure_id="1" + base_url="https://cloud-foo.griptape.ai", api_key="foo bar", structure_id="1", env={"key": "value"} ) def test_run(self, driver): diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index 04da4f2cf..921a4b013 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -1,4 +1,6 @@ +import os import pytest +from griptape.artifacts.text_artifact import TextArtifact from griptape.tasks import StructureRunTask from griptape.structures import Agent from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -7,18 +9,23 @@ class TestLocalStructureRunDriver: - @pytest.fixture - def driver(self): - agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output")) - driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) + def test_run(self): + pipeline = Pipeline() + driver = LocalStructureRunDriver(structure_factory_fn=lambda: Agent(prompt_driver=MockPromptDriver())) - return driver + task = StructureRunTask(driver=driver) + + pipeline.add_task(task) + + assert task.run().to_text() == "mock output" - def test_run(self, driver): - pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output")) + def test_run_with_env(self): + pipeline = Pipeline() + agent = Agent(prompt_driver=MockPromptDriver(mock_output=lambda: os.environ["key"])) + driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"key": "value"}) task = StructureRunTask(driver=driver) pipeline.add_task(task) - assert task.run().to_text() == "agent mock output" + assert task.run().to_text() == "value" From 427df75d61d8862c043b6b1045a1cbdc85aaf5af Mon Sep 17 00:00:00 2001 From: Matt Vallillo Date: Fri, 31 May 2024 10:43:40 -0700 Subject: [PATCH 2/4] chore: update `attr` to `attrs` (#810) Co-authored-by: Collin Dutter --- griptape/artifacts/audio_artifact.py | 2 +- griptape/artifacts/base_artifact.py | 2 +- griptape/artifacts/blob_artifact.py | 2 +- griptape/artifacts/csv_row_artifact.py | 2 +- griptape/artifacts/error_artifact.py | 2 +- griptape/artifacts/image_artifact.py | 2 +- griptape/artifacts/info_artifact.py | 2 +- griptape/artifacts/list_artifact.py | 2 +- griptape/artifacts/media_artifact.py | 2 +- griptape/artifacts/text_artifact.py | 2 +- griptape/chunkers/base_chunker.py | 2 +- griptape/config/base_structure_config.py | 2 +- .../amazon_bedrock_cohere_embedding_driver.py | 2 +- .../amazon_bedrock_titan_embedding_driver.py | 2 +- .../amazon_sagemaker_embedding_driver.py | 2 +- .../azure_openai_embedding_driver.py | 2 +- .../embedding/base_embedding_driver.py | 2 +- .../base_multi_model_embedding_driver.py | 2 +- .../embedding/google_embedding_driver.py | 2 +- .../huggingface_hub_embedding_driver.py | 2 +- .../embedding/openai_embedding_driver.py | 2 +- .../embedding/voyageai_embedding_driver.py | 2 +- .../base_embedding_model_driver.py | 2 +- ...aker_huggingface_embedding_model_driver.py | 2 +- ...r_tensorflow_hub_embedding_model_driver.py | 2 +- .../amazon_sqs_event_listener_driver.py | 2 +- .../aws_iot_core_event_listener_driver.py | 2 +- .../base_event_listener_driver.py | 2 +- .../griptape_cloud_event_listener_driver.py | 2 +- .../webhook_event_listener_driver.py | 2 +- .../amazon_s3_file_manager_driver.py | 2 +- .../file_manager/base_file_manager_driver.py | 2 +- .../file_manager/local_file_manager_driver.py | 2 +- .../amazon_bedrock_image_generation_driver.py | 2 +- .../azure_openai_image_generation_driver.py | 2 +- .../base_image_generation_driver.py | 2 +- ...ase_multi_model_image_generation_driver.py | 2 +- .../leonardo_image_generation_driver.py | 2 +- .../openai_image_generation_driver.py | 2 +- .../base_image_generation_model_driver.py | 2 +- ...diffusion_image_generation_model_driver.py | 2 +- ...ock_titan_image_generation_model_driver.py | 2 +- .../amazon_bedrock_image_query_driver.py | 2 +- .../anthropic_image_query_driver.py | 2 +- .../azure_openai_image_query_driver.py | 2 +- .../image_query/base_image_query_driver.py | 2 +- .../base_multi_model_image_query_driver.py | 2 +- .../image_query/dummy_image_query_driver.py | 2 +- .../image_query/openai_image_query_driver.py | 2 +- .../base_image_query_model_driver.py | 2 +- ...bedrock_claude_image_query_model_driver.py | 2 +- ...zon_dynamodb_conversation_memory_driver.py | 2 +- .../local_conversation_memory_driver.py | 2 +- .../redis_conversation_memory_driver.py | 2 +- .../prompt/amazon_bedrock_prompt_driver.py | 2 +- .../prompt/amazon_sagemaker_prompt_driver.py | 2 +- .../drivers/prompt/anthropic_prompt_driver.py | 2 +- .../prompt/azure_openai_chat_prompt_driver.py | 2 +- .../azure_openai_completion_prompt_driver.py | 2 +- griptape/drivers/prompt/base_prompt_driver.py | 2 +- .../drivers/prompt/cohere_prompt_driver.py | 2 +- .../drivers/prompt/google_prompt_driver.py | 2 +- .../prompt/huggingface_hub_prompt_driver.py | 2 +- .../huggingface_pipeline_prompt_driver.py | 2 +- .../prompt/openai_chat_prompt_driver.py | 2 +- .../prompt/openai_completion_prompt_driver.py | 2 +- .../prompt_model/base_prompt_model_driver.py | 2 +- .../bedrock_claude_prompt_model_driver.py | 2 +- .../bedrock_jurassic_prompt_model_driver.py | 2 +- .../bedrock_llama_prompt_model_driver.py | 2 +- .../bedrock_titan_prompt_model_driver.py | 2 +- .../sagemaker_falcon_prompt_model_driver.py | 2 +- .../sagemaker_llama_prompt_model_driver.py | 2 +- .../drivers/sql/amazon_redshift_sql_driver.py | 2 +- griptape/drivers/sql/base_sql_driver.py | 2 +- griptape/drivers/sql/snowflake_sql_driver.py | 2 +- griptape/drivers/sql/sql_driver.py | 2 +- .../base_text_to_speech_driver.py | 2 +- .../elevenlabs_text_to_speech_driver.py | 2 +- .../openai_text_to_speech_driver.py | 2 +- .../amazon_opensearch_vector_store_driver.py | 2 +- .../azure_mongodb_vector_store_driver.py | 2 +- .../vector/base_vector_store_driver.py | 2 +- .../vector/local_vector_store_driver.py | 2 +- .../vector/marqo_vector_store_driver.py | 2 +- .../mongodb_atlas_vector_store_driver.py | 2 +- .../vector/opensearch_vector_store_driver.py | 2 +- .../vector/pgvector_vector_store_driver.py | 2 +- .../vector/pinecone_vector_store_driver.py | 2 +- .../vector/redis_vector_store_driver.py | 2 +- .../markdownify_web_scraper_driver.py | 2 +- .../trafilatura_web_scraper_driver.py | 2 +- .../engines/audio/text_to_speech_engine.py | 2 +- .../extraction/base_extraction_engine.py | 2 +- .../extraction/csv_extraction_engine.py | 2 +- .../extraction/json_extraction_engine.py | 2 +- .../image/base_image_generation_engine.py | 2 +- .../inpainting_image_generation_engine.py | 2 +- .../outpainting_image_generation_engine.py | 2 +- .../image/prompt_image_generation_engine.py | 2 +- .../variation_image_generation_engine.py | 2 +- .../engines/image_query/image_query_engine.py | 2 +- griptape/engines/query/base_query_engine.py | 2 +- griptape/engines/query/vector_query_engine.py | 2 +- .../engines/summary/base_summary_engine.py | 2 +- .../engines/summary/prompt_summary_engine.py | 2 +- griptape/events/base_event.py | 2 +- griptape/events/base_image_query_event.py | 2 +- griptape/events/completion_chunk_event.py | 2 +- griptape/events/finish_image_query_event.py | 2 +- .../events/start_image_generation_event.py | 2 +- griptape/events/start_image_query_event.py | 2 +- griptape/events/start_text_to_speech_event.py | 2 +- griptape/loaders/base_loader.py | 2 +- griptape/loaders/blob_loader.py | 2 +- griptape/loaders/csv_loader.py | 2 +- griptape/loaders/dataframe_loader.py | 2 +- griptape/loaders/email_loader.py | 2 +- griptape/loaders/image_loader.py | 2 +- griptape/loaders/pdf_loader.py | 2 +- griptape/loaders/sql_loader.py | 2 +- griptape/loaders/text_loader.py | 2 +- griptape/loaders/web_loader.py | 2 +- .../memory/meta/action_subtask_meta_entry.py | 2 +- griptape/memory/meta/base_meta_entry.py | 2 +- griptape/memory/meta/meta_memory.py | 2 +- .../structure/base_conversation_memory.py | 2 +- .../memory/structure/conversation_memory.py | 2 +- griptape/memory/structure/run.py | 2 +- .../structure/summary_conversation_memory.py | 2 +- .../task/storage/base_artifact_storage.py | 2 +- .../task/storage/blob_artifact_storage.py | 2 +- .../task/storage/text_artifact_storage.py | 2 +- griptape/memory/task/task_memory.py | 2 +- .../mixins/actions_subtask_origin_mixin.py | 2 +- griptape/mixins/activity_mixin.py | 2 +- griptape/mixins/exponential_backoff_mixin.py | 6 +-- .../media_artifact_file_output_mixin.py | 2 +- griptape/mixins/rule_mixin.py | 2 +- griptape/mixins/serializable_mixin.py | 2 +- griptape/rules/rule.py | 2 +- griptape/rules/ruleset.py | 2 +- griptape/structures/agent.py | 2 +- griptape/structures/pipeline.py | 2 +- griptape/structures/workflow.py | 2 +- griptape/tasks/actions_subtask.py | 2 +- griptape/tasks/base_audio_generation_task.py | 2 +- griptape/tasks/base_image_generation_task.py | 2 +- griptape/tasks/base_multi_text_input_task.py | 2 +- griptape/tasks/base_task.py | 2 +- griptape/tasks/base_text_input_task.py | 2 +- griptape/tasks/code_execution_task.py | 2 +- griptape/tasks/csv_extraction_task.py | 2 +- griptape/tasks/extraction_task.py | 2 +- griptape/tasks/image_query_task.py | 2 +- .../tasks/inpainting_image_generation_task.py | 2 +- .../outpainting_image_generation_task.py | 2 +- .../tasks/prompt_image_generation_task.py | 2 +- griptape/tasks/prompt_task.py | 2 +- griptape/tasks/structure_run_task.py | 2 +- griptape/tasks/text_query_task.py | 2 +- griptape/tasks/text_summary_task.py | 2 +- griptape/tasks/text_to_speech_task.py | 2 +- griptape/tasks/tool_task.py | 2 +- griptape/tasks/toolkit_task.py | 2 +- .../tasks/variation_image_generation_task.py | 2 +- griptape/tokenizers/anthropic_tokenizer.py | 2 +- griptape/tokenizers/base_tokenizer.py | 2 +- .../tokenizers/bedrock_claude_tokenizer.py | 2 +- .../tokenizers/bedrock_cohere_tokenizer.py | 2 +- .../tokenizers/bedrock_jurassic_tokenizer.py | 2 +- .../tokenizers/bedrock_llama_tokenizer.py | 2 +- .../tokenizers/bedrock_titan_tokenizer.py | 2 +- griptape/tokenizers/cohere_tokenizer.py | 2 +- griptape/tokenizers/google_tokenizer.py | 2 +- griptape/tokenizers/huggingface_tokenizer.py | 2 +- griptape/tokenizers/openai_tokenizer.py | 2 +- griptape/tokenizers/simple_tokenizer.py | 2 +- griptape/tokenizers/voyageai_tokenizer.py | 2 +- griptape/tools/aws_iam_client/tool.py | 2 +- griptape/tools/aws_s3_client/tool.py | 2 +- griptape/tools/base_aws_client.py | 2 +- griptape/tools/base_google_client.py | 2 +- griptape/tools/base_griptape_cloud_client.py | 2 +- griptape/tools/base_tool.py | 2 +- griptape/tools/computer/tool.py | 2 +- griptape/tools/email_client/tool.py | 2 +- griptape/tools/file_manager/tool.py | 2 +- griptape/tools/google_cal/tool.py | 2 +- griptape/tools/google_docs/tool.py | 2 +- griptape/tools/google_drive/tool.py | 2 +- griptape/tools/google_gmail/tool.py | 2 +- .../tool.py | 2 +- griptape/tools/image_query_client/tool.py | 2 +- griptape/tools/openweather_client/tool.py | 2 +- griptape/tools/rest_api_client/tool.py | 6 +-- griptape/tools/sql_client/tool.py | 2 +- griptape/tools/structure_run_client/tool.py | 2 +- griptape/tools/task_memory_client/tool.py | 2 +- griptape/tools/vector_store_client/tool.py | 2 +- griptape/tools/web_scraper/tool.py | 2 +- griptape/tools/web_search/tool.py | 2 +- griptape/utils/chat.py | 2 +- griptape/utils/command_runner.py | 2 +- griptape/utils/conversation.py | 2 +- griptape/utils/j2.py | 2 +- griptape/utils/prompt_stack.py | 2 +- griptape/utils/python_runner.py | 2 +- griptape/utils/token_counter.py | 2 +- poetry.lock | 39 +++++++++---------- pyproject.toml | 8 +++- tests/mocks/invalid_mock_tool/tool.py | 2 +- tests/mocks/mock_embedding_driver.py | 2 +- tests/mocks/mock_event_listener_driver.py | 2 +- tests/mocks/mock_failing_prompt_driver.py | 4 +- tests/mocks/mock_image_generation_driver.py | 2 +- tests/mocks/mock_image_generation_task.py | 2 +- tests/mocks/mock_image_query_driver.py | 2 +- tests/mocks/mock_multi_text_input_task.py | 2 +- tests/mocks/mock_prompt_driver.py | 3 +- tests/mocks/mock_task.py | 2 +- tests/mocks/mock_text_input_task.py | 2 +- tests/mocks/mock_tokenizer.py | 2 +- tests/mocks/mock_tool/tool.py | 2 +- tests/mocks/mock_value_prompt_driver.py | 2 +- .../artifacts/test_base_media_artifact.py | 2 +- tests/utils/structure_tester.py | 2 +- 227 files changed, 257 insertions(+), 251 deletions(-) diff --git a/griptape/artifacts/audio_artifact.py b/griptape/artifacts/audio_artifact.py index b6d4667e4..3dc67fa36 100644 --- a/griptape/artifacts/audio_artifact.py +++ b/griptape/artifacts/audio_artifact.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attr import define +from attrs import define from griptape.artifacts import MediaArtifact diff --git a/griptape/artifacts/base_artifact.py b/griptape/artifacts/base_artifact.py index daaa32c4f..a7a1811ea 100644 --- a/griptape/artifacts/base_artifact.py +++ b/griptape/artifacts/base_artifact.py @@ -4,7 +4,7 @@ import json import uuid from abc import ABC, abstractmethod -from attr import define, field, Factory +from attrs import define, field, Factory @define() diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 7b5173736..5d2f32272 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -1,7 +1,7 @@ from __future__ import annotations import os.path from typing import Optional -from attr import field, define +from attrs import field, define from griptape.artifacts import BaseArtifact diff --git a/griptape/artifacts/csv_row_artifact.py b/griptape/artifacts/csv_row_artifact.py index 14fc1c0f9..7572b7a2f 100644 --- a/griptape/artifacts/csv_row_artifact.py +++ b/griptape/artifacts/csv_row_artifact.py @@ -1,7 +1,7 @@ from __future__ import annotations import csv import io -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact, BaseArtifact diff --git a/griptape/artifacts/error_artifact.py b/griptape/artifacts/error_artifact.py index f2fb02a8f..1900002b3 100644 --- a/griptape/artifacts/error_artifact.py +++ b/griptape/artifacts/error_artifact.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional -from attr import define, field +from attrs import define, field from griptape.artifacts import BaseArtifact diff --git a/griptape/artifacts/image_artifact.py b/griptape/artifacts/image_artifact.py index 068cf6d00..e963b3881 100644 --- a/griptape/artifacts/image_artifact.py +++ b/griptape/artifacts/image_artifact.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.artifacts import MediaArtifact diff --git a/griptape/artifacts/info_artifact.py b/griptape/artifacts/info_artifact.py index 19b50a043..3692e9631 100644 --- a/griptape/artifacts/info_artifact.py +++ b/griptape/artifacts/info_artifact.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.artifacts import BaseArtifact diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 3a8c93d50..558f32432 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -1,6 +1,6 @@ from typing import Optional from collections.abc import Sequence -from attr import field, define +from attrs import field, define from griptape.artifacts import BaseArtifact diff --git a/griptape/artifacts/media_artifact.py b/griptape/artifacts/media_artifact.py index dc78ea352..92cc0d9cd 100644 --- a/griptape/artifacts/media_artifact.py +++ b/griptape/artifacts/media_artifact.py @@ -5,7 +5,7 @@ import random from typing import Optional -from attr import define, field +from attrs import define, field from griptape.artifacts import BlobArtifact import base64 diff --git a/griptape/artifacts/text_artifact.py b/griptape/artifacts/text_artifact.py index ef0567d03..e8a2bb2a7 100644 --- a/griptape/artifacts/text_artifact.py +++ b/griptape/artifacts/text_artifact.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -from attr import define, field +from attrs import define, field from griptape.artifacts import BaseArtifact if TYPE_CHECKING: diff --git a/griptape/chunkers/base_chunker.py b/griptape/chunkers/base_chunker.py index c9700a91f..f2cc452ad 100644 --- a/griptape/chunkers/base_chunker.py +++ b/griptape/chunkers/base_chunker.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC from typing import Optional -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.chunkers import ChunkSeparator from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index d716205c8..a94cf75d9 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -3,7 +3,7 @@ from abc import ABC from typing import Optional -from attr import define, field +from attrs import define, field from griptape.config import BaseConfig from griptape.drivers import ( diff --git a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py index e3161b38a..15ce67c4c 100644 --- a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations import json from typing import Any, TYPE_CHECKING -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers import BedrockCohereTokenizer from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py index f3e0cb7f7..a510c618c 100644 --- a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations import json from typing import Any, TYPE_CHECKING -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers import BedrockTitanTokenizer from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py b/griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py index 376836bbd..4ab6d2bf7 100644 --- a/griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py @@ -3,7 +3,7 @@ import json from typing import Any -from attr import Factory, define, field +from attrs import Factory, define, field from griptape.drivers import BaseMultiModelEmbeddingDriver from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/embedding/azure_openai_embedding_driver.py b/griptape/drivers/embedding/azure_openai_embedding_driver.py index f2c3b8d39..c92197e9b 100644 --- a/griptape/drivers/embedding/azure_openai_embedding_driver.py +++ b/griptape/drivers/embedding/azure_openai_embedding_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Callable, Optional -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.drivers import OpenAiEmbeddingDriver from griptape.tokenizers import OpenAiTokenizer import openai diff --git a/griptape/drivers/embedding/base_embedding_driver.py b/griptape/drivers/embedding/base_embedding_driver.py index 6f74b9306..0fcc05fc1 100644 --- a/griptape/drivers/embedding/base_embedding_driver.py +++ b/griptape/drivers/embedding/base_embedding_driver.py @@ -2,7 +2,7 @@ import numpy as np from typing import Optional from abc import ABC, abstractmethod -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.mixins import ExponentialBackoffMixin from griptape.tokenizers import BaseTokenizer diff --git a/griptape/drivers/embedding/base_multi_model_embedding_driver.py b/griptape/drivers/embedding/base_multi_model_embedding_driver.py index 792c13c4b..90f827ad2 100644 --- a/griptape/drivers/embedding/base_multi_model_embedding_driver.py +++ b/griptape/drivers/embedding/base_multi_model_embedding_driver.py @@ -2,7 +2,7 @@ from abc import ABC from typing import TYPE_CHECKING -from attr import define, field +from attrs import define, field from griptape.drivers import BaseEmbeddingDriver diff --git a/griptape/drivers/embedding/google_embedding_driver.py b/griptape/drivers/embedding/google_embedding_driver.py index 3467d3360..884a40c3c 100644 --- a/griptape/drivers/embedding/google_embedding_driver.py +++ b/griptape/drivers/embedding/google_embedding_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional -from attr import define, field +from attrs import define, field from griptape.drivers import BaseEmbeddingDriver from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py index 8e90513cc..71abef81f 100644 --- a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py +++ b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from griptape.utils import import_optional_dependency -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.drivers import BaseEmbeddingDriver if TYPE_CHECKING: diff --git a/griptape/drivers/embedding/openai_embedding_driver.py b/griptape/drivers/embedding/openai_embedding_driver.py index 676a6209b..089875c1a 100644 --- a/griptape/drivers/embedding/openai_embedding_driver.py +++ b/griptape/drivers/embedding/openai_embedding_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers import OpenAiTokenizer import openai diff --git a/griptape/drivers/embedding/voyageai_embedding_driver.py b/griptape/drivers/embedding/voyageai_embedding_driver.py index 1f1779b50..0cfac6fda 100644 --- a/griptape/drivers/embedding/voyageai_embedding_driver.py +++ b/griptape/drivers/embedding/voyageai_embedding_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional, Any -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.utils import import_optional_dependency from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers import VoyageAiTokenizer diff --git a/griptape/drivers/embedding_model/base_embedding_model_driver.py b/griptape/drivers/embedding_model/base_embedding_model_driver.py index d2b169495..ad7bf3bda 100644 --- a/griptape/drivers/embedding_model/base_embedding_model_driver.py +++ b/griptape/drivers/embedding_model/base_embedding_model_driver.py @@ -1,4 +1,4 @@ -from attr import define +from attrs import define from abc import ABC, abstractmethod diff --git a/griptape/drivers/embedding_model/sagemaker_huggingface_embedding_model_driver.py b/griptape/drivers/embedding_model/sagemaker_huggingface_embedding_model_driver.py index 602902c3f..dceffcd8a 100644 --- a/griptape/drivers/embedding_model/sagemaker_huggingface_embedding_model_driver.py +++ b/griptape/drivers/embedding_model/sagemaker_huggingface_embedding_model_driver.py @@ -1,4 +1,4 @@ -from attr import define +from attrs import define from griptape.drivers import BaseEmbeddingModelDriver diff --git a/griptape/drivers/embedding_model/sagemaker_tensorflow_hub_embedding_model_driver.py b/griptape/drivers/embedding_model/sagemaker_tensorflow_hub_embedding_model_driver.py index 1b509a698..9d9632fb0 100644 --- a/griptape/drivers/embedding_model/sagemaker_tensorflow_hub_embedding_model_driver.py +++ b/griptape/drivers/embedding_model/sagemaker_tensorflow_hub_embedding_model_driver.py @@ -1,4 +1,4 @@ -from attr import define +from attrs import define from griptape.drivers import BaseEmbeddingModelDriver diff --git a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py index 1c8132b67..4c632cb01 100644 --- a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py +++ b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py @@ -3,7 +3,7 @@ import json from typing import TYPE_CHECKING, Any -from attr import Factory, define, field +from attrs import Factory, define, field from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py index c4fd72084..3b014aed4 100644 --- a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py +++ b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py @@ -3,7 +3,7 @@ import json from typing import TYPE_CHECKING, Any -from attr import Factory, define, field +from attrs import Factory, define, field from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/event_listener/base_event_listener_driver.py b/griptape/drivers/event_listener/base_event_listener_driver.py index 9ff18cf4f..b6d2d9b12 100644 --- a/griptape/drivers/event_listener/base_event_listener_driver.py +++ b/griptape/drivers/event_listener/base_event_listener_driver.py @@ -4,7 +4,7 @@ from concurrent import futures from logging import Logger -from attr import Factory, define, field +from attrs import Factory, define, field from griptape.events import BaseEvent diff --git a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py index 2c4149ae7..c481d3081 100644 --- a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py +++ b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py @@ -4,7 +4,7 @@ import requests from urllib.parse import urljoin -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver diff --git a/griptape/drivers/event_listener/webhook_event_listener_driver.py b/griptape/drivers/event_listener/webhook_event_listener_driver.py index 242e5428a..a0eb5ab5f 100644 --- a/griptape/drivers/event_listener/webhook_event_listener_driver.py +++ b/griptape/drivers/event_listener/webhook_event_listener_driver.py @@ -2,7 +2,7 @@ import requests -from attr import define, field +from attrs import define, field from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver diff --git a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py index cbf2109ef..bdb4d787c 100644 --- a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py +++ b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py @@ -2,7 +2,7 @@ import os from pathlib import Path from typing import TYPE_CHECKING, Any -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.utils.import_utils import import_optional_dependency from .base_file_manager_driver import BaseFileManagerDriver diff --git a/griptape/drivers/file_manager/base_file_manager_driver.py b/griptape/drivers/file_manager/base_file_manager_driver.py index 1cff06d06..56f19b3cc 100644 --- a/griptape/drivers/file_manager/base_file_manager_driver.py +++ b/griptape/drivers/file_manager/base_file_manager_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from attr import Factory, define, field +from attrs import Factory, define, field from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact, InfoArtifact, ListArtifact import griptape.loaders as loaders diff --git a/griptape/drivers/file_manager/local_file_manager_driver.py b/griptape/drivers/file_manager/local_file_manager_driver.py index 96766fa9a..186296aa3 100644 --- a/griptape/drivers/file_manager/local_file_manager_driver.py +++ b/griptape/drivers/file_manager/local_file_manager_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations import os from pathlib import Path -from attr import define, field, Factory +from attrs import define, field, Factory from .base_file_manager_driver import BaseFileManagerDriver diff --git a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py index e3e72d55e..2edb9f862 100644 --- a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py +++ b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py @@ -3,7 +3,7 @@ import json from typing import TYPE_CHECKING, Any, Optional -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import ImageArtifact from griptape.drivers import BaseMultiModelImageGenerationDriver diff --git a/griptape/drivers/image_generation/azure_openai_image_generation_driver.py b/griptape/drivers/image_generation/azure_openai_image_generation_driver.py index f6cbbe1a9..b631703d8 100644 --- a/griptape/drivers/image_generation/azure_openai_image_generation_driver.py +++ b/griptape/drivers/image_generation/azure_openai_image_generation_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations import openai -from attr import field, Factory, define +from attrs import field, Factory, define from typing import Callable, Optional from griptape.drivers import OpenAiImageGenerationDriver diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index 771fc30dc..dbe42442f 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from attr import define, field +from attrs import define, field from griptape.artifacts import ImageArtifact from griptape.events import StartImageGenerationEvent, FinishImageGenerationEvent diff --git a/griptape/drivers/image_generation/base_multi_model_image_generation_driver.py b/griptape/drivers/image_generation/base_multi_model_image_generation_driver.py index 9b1bb9932..12c2fbef5 100644 --- a/griptape/drivers/image_generation/base_multi_model_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_multi_model_image_generation_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC -from attr import field, define +from attrs import field, define from griptape.drivers import BaseImageGenerationDriver, BaseImageGenerationModelDriver diff --git a/griptape/drivers/image_generation/leonardo_image_generation_driver.py b/griptape/drivers/image_generation/leonardo_image_generation_driver.py index c41b15d97..d274970ee 100644 --- a/griptape/drivers/image_generation/leonardo_image_generation_driver.py +++ b/griptape/drivers/image_generation/leonardo_image_generation_driver.py @@ -3,7 +3,7 @@ from typing import Optional, Literal import requests -from attr import field, define, Factory +from attrs import field, define, Factory from griptape.artifacts import ImageArtifact from griptape.drivers import BaseImageGenerationDriver diff --git a/griptape/drivers/image_generation/openai_image_generation_driver.py b/griptape/drivers/image_generation/openai_image_generation_driver.py index 327132f64..abc961e3d 100644 --- a/griptape/drivers/image_generation/openai_image_generation_driver.py +++ b/griptape/drivers/image_generation/openai_image_generation_driver.py @@ -5,7 +5,7 @@ import openai from openai.types.images_response import ImagesResponse -from attr import field, Factory, define +from attrs import field, Factory, define from griptape.artifacts import ImageArtifact from griptape.drivers import BaseImageGenerationDriver diff --git a/griptape/drivers/image_generation_model/base_image_generation_model_driver.py b/griptape/drivers/image_generation_model/base_image_generation_model_driver.py index 328bbf4d7..803863319 100644 --- a/griptape/drivers/image_generation_model/base_image_generation_model_driver.py +++ b/griptape/drivers/image_generation_model/base_image_generation_model_driver.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Optional -from attr import define +from attrs import define from griptape.artifacts import ImageArtifact from griptape.mixins import SerializableMixin diff --git a/griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py b/griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py index 5018b42e4..03d593eac 100644 --- a/griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py +++ b/griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py @@ -4,7 +4,7 @@ import logging from typing import Optional -from attr import field, define +from attrs import field, define from griptape.artifacts import ImageArtifact from griptape.drivers import BaseImageGenerationModelDriver diff --git a/griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py b/griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py index 8602b386b..2f4577aa7 100644 --- a/griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py +++ b/griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py @@ -3,7 +3,7 @@ import base64 from typing import Any, Optional -from attr import field, define +from attrs import field, define from griptape.artifacts import ImageArtifact from griptape.drivers import BaseImageGenerationModelDriver diff --git a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py b/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py index a5a7c6f15..eabe9d27e 100644 --- a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py +++ b/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers import BaseMultiModelImageQueryDriver from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/image_query/anthropic_image_query_driver.py b/griptape/drivers/image_query/anthropic_image_query_driver.py index cfb1a1d48..aca06ac2f 100644 --- a/griptape/drivers/image_query/anthropic_image_query_driver.py +++ b/griptape/drivers/image_query/anthropic_image_query_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional, Any -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers import BaseImageQueryDriver from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/image_query/azure_openai_image_query_driver.py b/griptape/drivers/image_query/azure_openai_image_query_driver.py index 017c98f1f..f59c64823 100644 --- a/griptape/drivers/image_query/azure_openai_image_query_driver.py +++ b/griptape/drivers/image_query/azure_openai_image_query_driver.py @@ -2,7 +2,7 @@ from typing import Callable, Optional -from attr import define, field, Factory +from attrs import define, field, Factory import openai from griptape.drivers.image_query.openai_image_query_driver import OpenAiImageQueryDriver diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index 8f7425732..8944dd931 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -3,7 +3,7 @@ from abc import abstractmethod, ABC from typing import Optional, TYPE_CHECKING -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact, ImageArtifact from griptape.events import StartImageQueryEvent, FinishImageQueryEvent diff --git a/griptape/drivers/image_query/base_multi_model_image_query_driver.py b/griptape/drivers/image_query/base_multi_model_image_query_driver.py index 07d5b7c27..d801fd917 100644 --- a/griptape/drivers/image_query/base_multi_model_image_query_driver.py +++ b/griptape/drivers/image_query/base_multi_model_image_query_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC -from attr import field, define +from attrs import field, define from griptape.drivers import BaseImageQueryDriver, BaseImageQueryModelDriver diff --git a/griptape/drivers/image_query/dummy_image_query_driver.py b/griptape/drivers/image_query/dummy_image_query_driver.py index ebf579dd4..ddc6e2318 100644 --- a/griptape/drivers/image_query/dummy_image_query_driver.py +++ b/griptape/drivers/image_query/dummy_image_query_driver.py @@ -1,4 +1,4 @@ -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact, ImageArtifact from griptape.drivers import BaseImageQueryDriver diff --git a/griptape/drivers/image_query/openai_image_query_driver.py b/griptape/drivers/image_query/openai_image_query_driver.py index 515bdcc7c..8b0020c2c 100644 --- a/griptape/drivers/image_query/openai_image_query_driver.py +++ b/griptape/drivers/image_query/openai_image_query_driver.py @@ -2,7 +2,7 @@ from typing import Literal, Optional -from attr import define, field, Factory +from attrs import define, field, Factory from openai.types.chat import ( ChatCompletionUserMessageParam, ChatCompletionContentPartParam, diff --git a/griptape/drivers/image_query_model/base_image_query_model_driver.py b/griptape/drivers/image_query_model/base_image_query_model_driver.py index 9bddbb979..746a9f84c 100644 --- a/griptape/drivers/image_query_model/base_image_query_model_driver.py +++ b/griptape/drivers/image_query_model/base_image_query_model_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from attr import define +from attrs import define from griptape.artifacts import TextArtifact, ImageArtifact from griptape.mixins import SerializableMixin diff --git a/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py b/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py index df7dd326a..3d3ca0164 100644 --- a/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py +++ b/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define +from attrs import define from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers import BaseImageQueryModelDriver diff --git a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py index 545ad559b..87c3667fe 100644 --- a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field, Factory +from attrs import define, field, Factory from typing import Optional, TYPE_CHECKING, Any from griptape.utils import import_optional_dependency from griptape.drivers import BaseConversationMemoryDriver diff --git a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py index b22c1c896..9e3f4ffc9 100644 --- a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py @@ -1,5 +1,5 @@ import os -from attr import define, field +from attrs import define, field from typing import Optional from griptape.drivers import BaseConversationMemoryDriver from griptape.memory.structure import BaseConversationMemory diff --git a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py index 0531d5f9d..3de8737b8 100644 --- a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations import uuid -from attr import define, field, Factory +from attrs import define, field, Factory from typing import Optional, TYPE_CHECKING from griptape.drivers import BaseConversationMemoryDriver from griptape.memory.structure import BaseConversationMemory diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index b645da39a..0675e7f92 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -2,7 +2,7 @@ import json from typing import TYPE_CHECKING, Any from collections.abc import Iterator -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.utils import import_optional_dependency from .base_multi_model_prompt_driver import BaseMultiModelPromptDriver diff --git a/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py index bca558c72..342faf909 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py @@ -2,7 +2,7 @@ import json from typing import TYPE_CHECKING, Any from collections.abc import Iterator -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.utils import import_optional_dependency from .base_multi_model_prompt_driver import BaseMultiModelPromptDriver diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 35b411ac5..486233643 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Optional, Any from collections.abc import Iterator -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.utils import PromptStack, import_optional_dependency from griptape.drivers import BasePromptDriver diff --git a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py index 0583369c4..41c91cb65 100644 --- a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py @@ -1,4 +1,4 @@ -from attr import define, field, Factory +from attrs import define, field, Factory from typing import Callable, Optional from griptape.utils import PromptStack from griptape.drivers import OpenAiChatPromptDriver diff --git a/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py b/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py index 89f2651ef..4ff2a4902 100644 --- a/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py +++ b/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py @@ -1,5 +1,5 @@ from typing import Callable, Optional -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.drivers import OpenAiCompletionPromptDriver import openai diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index ef6a35bcf..096035f8b 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional, Callable from collections.abc import Iterator -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.events import StartPromptEvent, FinishPromptEvent, CompletionChunkEvent from griptape.mixins.serializable_mixin import SerializableMixin from griptape.utils import PromptStack diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 07ae7717d..2f85c49bf 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from collections.abc import Iterator -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver from griptape.tokenizers import CohereTokenizer diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index b576a440b..9f833c035 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterator from typing import TYPE_CHECKING, Optional, Any -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.utils import PromptStack, import_optional_dependency from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index a1a6f31d5..062672aa8 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -3,7 +3,7 @@ from collections.abc import Iterator from typing import TYPE_CHECKING -from attr import Factory, define, field +from attrs import Factory, define, field from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 9fb8b9ece..bde6d5e4e 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -1,6 +1,6 @@ from collections.abc import Iterator -from attr import Factory, define, field +from attrs import Factory, define, field from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index eee6e2d97..3d19063d3 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -2,7 +2,7 @@ from typing import Optional, Any, Literal from collections.abc import Iterator import openai -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.utils import PromptStack from griptape.drivers import BasePromptDriver diff --git a/griptape/drivers/prompt/openai_completion_prompt_driver.py b/griptape/drivers/prompt/openai_completion_prompt_driver.py index 92ae4f36f..1a738a487 100644 --- a/griptape/drivers/prompt/openai_completion_prompt_driver.py +++ b/griptape/drivers/prompt/openai_completion_prompt_driver.py @@ -1,6 +1,6 @@ from typing import Optional from collections.abc import Iterator -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.utils import PromptStack from griptape.drivers import BasePromptDriver diff --git a/griptape/drivers/prompt_model/base_prompt_model_driver.py b/griptape/drivers/prompt_model/base_prompt_model_driver.py index 9738127db..de9b6b6d6 100644 --- a/griptape/drivers/prompt_model/base_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/base_prompt_model_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Optional -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.utils import PromptStack from griptape.drivers import BasePromptDriver diff --git a/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py index 1233a8f04..032940d66 100644 --- a/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Optional import json -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.utils import PromptStack from griptape.drivers import BasePromptModelDriver, AmazonBedrockPromptDriver diff --git a/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py index 47d9f1199..3b51eef1c 100644 --- a/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Optional import json -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.utils import PromptStack from griptape.drivers import BasePromptModelDriver diff --git a/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py index ae7de03eb..6c6dab1d2 100644 --- a/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py @@ -2,7 +2,7 @@ import json import itertools as it from typing import Optional -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.utils import PromptStack from griptape.drivers import BasePromptModelDriver diff --git a/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py index 5e03e491a..00621416a 100644 --- a/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Optional import json -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.utils import PromptStack from griptape.drivers import BasePromptModelDriver diff --git a/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py b/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py index 0fd50386b..a6859f13c 100644 --- a/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.utils import PromptStack, import_optional_dependency from griptape.drivers import BasePromptModelDriver diff --git a/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py b/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py index 31f22427d..7864a70a8 100644 --- a/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.utils import PromptStack, import_optional_dependency from griptape.drivers import BasePromptModelDriver diff --git a/griptape/drivers/sql/amazon_redshift_sql_driver.py b/griptape/drivers/sql/amazon_redshift_sql_driver.py index 51acbfa04..e6ec8b2b9 100644 --- a/griptape/drivers/sql/amazon_redshift_sql_driver.py +++ b/griptape/drivers/sql/amazon_redshift_sql_driver.py @@ -2,7 +2,7 @@ import time from typing import Optional, TYPE_CHECKING, Any from griptape.drivers import BaseSqlDriver -from attr import Factory, define, field +from attrs import Factory, define, field if TYPE_CHECKING: import boto3 diff --git a/griptape/drivers/sql/base_sql_driver.py b/griptape/drivers/sql/base_sql_driver.py index 95d1d0fa2..564746389 100644 --- a/griptape/drivers/sql/base_sql_driver.py +++ b/griptape/drivers/sql/base_sql_driver.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Optional, Any -from attr import define +from attrs import define @define diff --git a/griptape/drivers/sql/snowflake_sql_driver.py b/griptape/drivers/sql/snowflake_sql_driver.py index c4519b6c8..4c85ab3a3 100644 --- a/griptape/drivers/sql/snowflake_sql_driver.py +++ b/griptape/drivers/sql/snowflake_sql_driver.py @@ -2,7 +2,7 @@ from typing import Callable, Optional, TYPE_CHECKING, Any from griptape.utils import import_optional_dependency from griptape.drivers import BaseSqlDriver -from attr import Factory, define, field +from attrs import Factory, define, field if TYPE_CHECKING: from sqlalchemy.engine import Engine diff --git a/griptape/drivers/sql/sql_driver.py b/griptape/drivers/sql/sql_driver.py index ad6d5637c..a789a90c3 100644 --- a/griptape/drivers/sql/sql_driver.py +++ b/griptape/drivers/sql/sql_driver.py @@ -2,7 +2,7 @@ from typing import Optional, TYPE_CHECKING, Any from griptape.drivers import BaseSqlDriver from griptape.utils import import_optional_dependency -from attr import define, field +from attrs import define, field if TYPE_CHECKING: from sqlalchemy.engine import Engine diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index e1365a69f..4f6dbcf25 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from attr import define, field +from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent diff --git a/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py b/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py index cf342b87b..02d55c023 100644 --- a/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py @@ -2,7 +2,7 @@ from typing import Any -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts.audio_artifact import AudioArtifact from griptape.drivers import BaseTextToSpeechDriver diff --git a/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py b/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py index e6022b29d..2d6e7b155 100644 --- a/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py @@ -3,7 +3,7 @@ from typing import Optional, Literal import openai -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts.audio_artifact import AudioArtifact from griptape.drivers import BaseTextToSpeechDriver diff --git a/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py b/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py index 15ff16546..668f89b7e 100644 --- a/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py +++ b/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field, Factory +from attrs import define, field, Factory from typing import Optional, TYPE_CHECKING from griptape.drivers import OpenSearchVectorStoreDriver from griptape.utils import import_optional_dependency, str_to_hash diff --git a/griptape/drivers/vector/azure_mongodb_vector_store_driver.py b/griptape/drivers/vector/azure_mongodb_vector_store_driver.py index 47eb7d0d1..2e9968b63 100644 --- a/griptape/drivers/vector/azure_mongodb_vector_store_driver.py +++ b/griptape/drivers/vector/azure_mongodb_vector_store_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional -from attr import define +from attrs import define from griptape.drivers import BaseVectorStoreDriver, MongoDbAtlasVectorStoreDriver diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index 0667df8a9..55a758da1 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from concurrent import futures from dataclasses import dataclass -from attr import define, field, Factory +from attrs import define, field, Factory from typing import Optional from griptape import utils from griptape.mixins import SerializableMixin diff --git a/griptape/drivers/vector/local_vector_store_driver.py b/griptape/drivers/vector/local_vector_store_driver.py index bea8dc579..4b9dae00f 100644 --- a/griptape/drivers/vector/local_vector_store_driver.py +++ b/griptape/drivers/vector/local_vector_store_driver.py @@ -3,7 +3,7 @@ from numpy.linalg import norm from griptape import utils from griptape.drivers import BaseVectorStoreDriver -from attr import define, field +from attrs import define, field @define diff --git a/griptape/drivers/vector/marqo_vector_store_driver.py b/griptape/drivers/vector/marqo_vector_store_driver.py index 100c75d79..cccc716f8 100644 --- a/griptape/drivers/vector/marqo_vector_store_driver.py +++ b/griptape/drivers/vector/marqo_vector_store_driver.py @@ -3,7 +3,7 @@ from griptape.utils import import_optional_dependency from griptape.drivers import BaseVectorStoreDriver from griptape.artifacts import TextArtifact -from attr import define, field, Factory +from attrs import define, field, Factory if TYPE_CHECKING: import marqo diff --git a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py index df47dc5e1..c5368de3d 100644 --- a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py +++ b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/vector/opensearch_vector_store_driver.py b/griptape/drivers/vector/opensearch_vector_store_driver.py index 41bae7254..7ff00d2de 100644 --- a/griptape/drivers/vector/opensearch_vector_store_driver.py +++ b/griptape/drivers/vector/opensearch_vector_store_driver.py @@ -4,7 +4,7 @@ import logging from griptape.utils import import_optional_dependency from griptape.drivers import BaseVectorStoreDriver -from attr import define, field, Factory +from attrs import define, field, Factory if TYPE_CHECKING: from opensearchpy import OpenSearch diff --git a/griptape/drivers/vector/pgvector_vector_store_driver.py b/griptape/drivers/vector/pgvector_vector_store_driver.py index 29d5afb67..8058c8774 100644 --- a/griptape/drivers/vector/pgvector_vector_store_driver.py +++ b/griptape/drivers/vector/pgvector_vector_store_driver.py @@ -1,6 +1,6 @@ import uuid from typing import Optional, Any, cast -from attr import define, field, Factory +from attrs import define, field, Factory from dataclasses import dataclass from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/vector/pinecone_vector_store_driver.py b/griptape/drivers/vector/pinecone_vector_store_driver.py index 7b573ed68..daf378087 100644 --- a/griptape/drivers/vector/pinecone_vector_store_driver.py +++ b/griptape/drivers/vector/pinecone_vector_store_driver.py @@ -2,7 +2,7 @@ from typing import Optional, TYPE_CHECKING, Any from griptape.utils import str_to_hash, import_optional_dependency from griptape.drivers import BaseVectorStoreDriver -from attr import define, field +from attrs import define, field if TYPE_CHECKING: import pinecone diff --git a/griptape/drivers/vector/redis_vector_store_driver.py b/griptape/drivers/vector/redis_vector_store_driver.py index b347bfa41..6e39775b7 100644 --- a/griptape/drivers/vector/redis_vector_store_driver.py +++ b/griptape/drivers/vector/redis_vector_store_driver.py @@ -4,7 +4,7 @@ import numpy as np from griptape.utils import import_optional_dependency, str_to_hash from typing import Optional, TYPE_CHECKING -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.drivers import BaseVectorStoreDriver logging.basicConfig(level=logging.WARNING) diff --git a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py index 3b4931579..eb33aeb19 100644 --- a/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/markdownify_web_scraper_driver.py @@ -1,6 +1,6 @@ import re from typing import Optional -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.drivers import BaseWebScraperDriver from griptape.utils import import_optional_dependency diff --git a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py index 787dd24d7..cc0ca3a18 100644 --- a/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py +++ b/griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py @@ -1,6 +1,6 @@ import json import logging -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.drivers import BaseWebScraperDriver from griptape.utils import import_optional_dependency diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index 29118848e..634837d82 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact from griptape.drivers import BaseTextToSpeechDriver diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index 760a77d39..e8bd29129 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Optional from abc import ABC, abstractmethod -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import ListArtifact, ErrorArtifact from griptape.chunkers import BaseChunker, TextChunker from griptape.drivers import BasePromptDriver diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 696abd6d3..fe4d0e6c7 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -2,7 +2,7 @@ from typing import Optional, cast import csv import io -from attr import field, Factory, define +from attrs import field, Factory, define from griptape.artifacts import TextArtifact, CsvRowArtifact, ListArtifact, ErrorArtifact from griptape.utils import PromptStack from griptape.engines import BaseExtractionEngine diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index e14e40d5a..05db19d40 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Optional, cast import json -from attr import field, Factory, define +from attrs import field, Factory, define from griptape.artifacts import TextArtifact, ListArtifact, ErrorArtifact from griptape.engines import BaseExtractionEngine from griptape.utils import J2 diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 73dd5ac74..2c65c1a60 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from attr import field, define +from attrs import field, define from typing import Optional from griptape.artifacts import ImageArtifact diff --git a/griptape/engines/image/inpainting_image_generation_engine.py b/griptape/engines/image/inpainting_image_generation_engine.py index 3f1295866..7fb83d66f 100644 --- a/griptape/engines/image/inpainting_image_generation_engine.py +++ b/griptape/engines/image/inpainting_image_generation_engine.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attr import define +from attrs import define from typing import Optional from griptape.engines import BaseImageGenerationEngine diff --git a/griptape/engines/image/outpainting_image_generation_engine.py b/griptape/engines/image/outpainting_image_generation_engine.py index 0d4008543..135e3c77d 100644 --- a/griptape/engines/image/outpainting_image_generation_engine.py +++ b/griptape/engines/image/outpainting_image_generation_engine.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attr import define +from attrs import define from typing import Optional from griptape.artifacts import ImageArtifact diff --git a/griptape/engines/image/prompt_image_generation_engine.py b/griptape/engines/image/prompt_image_generation_engine.py index c7d304beb..4b4a9ce63 100644 --- a/griptape/engines/image/prompt_image_generation_engine.py +++ b/griptape/engines/image/prompt_image_generation_engine.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attr import define +from attrs import define from typing import Optional from griptape.rules import Ruleset diff --git a/griptape/engines/image/variation_image_generation_engine.py b/griptape/engines/image/variation_image_generation_engine.py index e07a6fc9d..7b932ac2a 100644 --- a/griptape/engines/image/variation_image_generation_engine.py +++ b/griptape/engines/image/variation_image_generation_engine.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attr import define +from attrs import define from typing import Optional from griptape.engines import BaseImageGenerationEngine diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index 4cbac9755..9cb61cd92 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -1,4 +1,4 @@ -from attr import define, field +from attrs import define, field from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers import BaseImageQueryDriver diff --git a/griptape/engines/query/base_query_engine.py b/griptape/engines/query/base_query_engine.py index 9477b3bd0..61a488547 100644 --- a/griptape/engines/query/base_query_engine.py +++ b/griptape/engines/query/base_query_engine.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from attr import define +from attrs import define from typing import Optional from griptape.artifacts import TextArtifact, ListArtifact from griptape.rules import Ruleset diff --git a/griptape/engines/query/vector_query_engine.py b/griptape/engines/query/vector_query_engine.py index 34667d592..adfa4b2db 100644 --- a/griptape/engines/query/vector_query_engine.py +++ b/griptape/engines/query/vector_query_engine.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact from griptape.utils import PromptStack from griptape.engines import BaseQueryEngine diff --git a/griptape/engines/summary/base_summary_engine.py b/griptape/engines/summary/base_summary_engine.py index 4a5aca520..26a5cd46d 100644 --- a/griptape/engines/summary/base_summary_engine.py +++ b/griptape/engines/summary/base_summary_engine.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from typing import Optional -from attr import define +from attrs import define from griptape.artifacts import TextArtifact, ListArtifact from griptape.rules import Ruleset diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index bb74cfd6f..0da99cb0a 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -1,5 +1,5 @@ from typing import Optional, cast -from attr import define, Factory, field +from attrs import define, Factory, field from griptape.artifacts import TextArtifact, ListArtifact from griptape.chunkers import BaseChunker, TextChunker from griptape.utils import PromptStack diff --git a/griptape/events/base_event.py b/griptape/events/base_event.py index 48a48890e..9ab8e6c47 100644 --- a/griptape/events/base_event.py +++ b/griptape/events/base_event.py @@ -4,7 +4,7 @@ import uuid from abc import ABC -from attr import Factory, define, field +from attrs import Factory, define, field from griptape.mixins import SerializableMixin diff --git a/griptape/events/base_image_query_event.py b/griptape/events/base_image_query_event.py index 8ed796072..b634f2bf0 100644 --- a/griptape/events/base_image_query_event.py +++ b/griptape/events/base_image_query_event.py @@ -1,6 +1,6 @@ from abc import ABC -from attr import define +from attrs import define from griptape.events import BaseEvent diff --git a/griptape/events/completion_chunk_event.py b/griptape/events/completion_chunk_event.py index 829b07c1b..a4244bd5d 100644 --- a/griptape/events/completion_chunk_event.py +++ b/griptape/events/completion_chunk_event.py @@ -1,4 +1,4 @@ -from attr import field +from attrs import field from attrs import define from griptape.events.base_event import BaseEvent diff --git a/griptape/events/finish_image_query_event.py b/griptape/events/finish_image_query_event.py index bdcad2ca3..3eb2e7ccb 100644 --- a/griptape/events/finish_image_query_event.py +++ b/griptape/events/finish_image_query_event.py @@ -1,4 +1,4 @@ -from attr import define, field +from attrs import define, field from griptape.events.base_image_query_event import BaseImageQueryEvent diff --git a/griptape/events/start_image_generation_event.py b/griptape/events/start_image_generation_event.py index a8f2d6b2a..c673f34c4 100644 --- a/griptape/events/start_image_generation_event.py +++ b/griptape/events/start_image_generation_event.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Optional -from attr import define, field +from attrs import define, field from .base_image_generation_event import BaseImageGenerationEvent diff --git a/griptape/events/start_image_query_event.py b/griptape/events/start_image_query_event.py index e72579b04..46001a43f 100644 --- a/griptape/events/start_image_query_event.py +++ b/griptape/events/start_image_query_event.py @@ -1,4 +1,4 @@ -from attr import define, field +from attrs import define, field from griptape.events.base_image_query_event import BaseImageQueryEvent diff --git a/griptape/events/start_text_to_speech_event.py b/griptape/events/start_text_to_speech_event.py index 9824a9827..4c3f27ca0 100644 --- a/griptape/events/start_text_to_speech_event.py +++ b/griptape/events/start_text_to_speech_event.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from .base_text_to_speech_event import BaseTextToSpeechEvent diff --git a/griptape/loaders/base_loader.py b/griptape/loaders/base_loader.py index 56c3e840c..1648b8f26 100644 --- a/griptape/loaders/base_loader.py +++ b/griptape/loaders/base_loader.py @@ -5,7 +5,7 @@ from typing import Any, Optional from collections.abc import Mapping, Sequence -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import BaseArtifact from griptape.utils.futures import execute_futures_dict diff --git a/griptape/loaders/blob_loader.py b/griptape/loaders/blob_loader.py index c85c63bcb..8b9ea4bf9 100644 --- a/griptape/loaders/blob_loader.py +++ b/griptape/loaders/blob_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Any, Union, cast -from attr import define +from attrs import define from griptape.artifacts import BlobArtifact, ErrorArtifact from griptape.loaders import BaseLoader diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 72fdcc73a..d396f80bc 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -3,7 +3,7 @@ from io import StringIO from typing import Optional, Union, cast -from attr import define, field +from attrs import define, field from griptape.artifacts import CsvRowArtifact, ErrorArtifact from griptape.drivers import BaseEmbeddingDriver diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py index 1c652d8d2..3d5a0f48b 100644 --- a/griptape/loaders/dataframe_loader.py +++ b/griptape/loaders/dataframe_loader.py @@ -2,7 +2,7 @@ from typing import Optional, TYPE_CHECKING, cast -from attr import define, field +from attrs import define, field from griptape.artifacts import CsvRowArtifact from griptape.drivers import BaseEmbeddingDriver diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py index dbb6f3afe..a7598e2a6 100644 --- a/griptape/loaders/email_loader.py +++ b/griptape/loaders/email_loader.py @@ -4,7 +4,7 @@ import logging import imaplib -from attr import astuple, define, field +from attrs import astuple, define, field from griptape.utils import import_optional_dependency from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact diff --git a/griptape/loaders/image_loader.py b/griptape/loaders/image_loader.py index 71908a3bc..ca0ba776a 100644 --- a/griptape/loaders/image_loader.py +++ b/griptape/loaders/image_loader.py @@ -3,7 +3,7 @@ from io import BytesIO from typing import Optional, cast -from attr import define, field +from attrs import define, field from griptape.utils import import_optional_dependency from griptape.artifacts import ImageArtifact diff --git a/griptape/loaders/pdf_loader.py b/griptape/loaders/pdf_loader.py index 66b1e1a01..8e4560cfc 100644 --- a/griptape/loaders/pdf_loader.py +++ b/griptape/loaders/pdf_loader.py @@ -1,7 +1,7 @@ from __future__ import annotations from io import BytesIO -from attr import define, field, Factory +from attrs import define, field, Factory from typing import Optional, Union, cast from griptape.artifacts.error_artifact import ErrorArtifact diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 527b74de4..c20fb022d 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -1,6 +1,6 @@ from typing import Optional, cast -from attr import define, field +from attrs import define, field from griptape.artifacts import CsvRowArtifact from griptape.drivers import BaseSqlDriver, BaseEmbeddingDriver diff --git a/griptape/loaders/text_loader.py b/griptape/loaders/text_loader.py index 3bba6023d..e8a80fa64 100644 --- a/griptape/loaders/text_loader.py +++ b/griptape/loaders/text_loader.py @@ -2,7 +2,7 @@ from typing import Optional, Union, cast -from attr import field, define, Factory +from attrs import field, define, Factory from griptape.artifacts import TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact diff --git a/griptape/loaders/web_loader.py b/griptape/loaders/web_loader.py index d6e38f7e0..f8862f7cd 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts.error_artifact import ErrorArtifact from griptape.drivers import BaseWebScraperDriver, TrafilaturaWebScraperDriver from griptape.artifacts import TextArtifact diff --git a/griptape/memory/meta/action_subtask_meta_entry.py b/griptape/memory/meta/action_subtask_meta_entry.py index 19c60e8fa..6b5124971 100644 --- a/griptape/memory/meta/action_subtask_meta_entry.py +++ b/griptape/memory/meta/action_subtask_meta_entry.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional -from attr import field, define +from attrs import field, define from griptape.memory.meta import BaseMetaEntry diff --git a/griptape/memory/meta/base_meta_entry.py b/griptape/memory/meta/base_meta_entry.py index a7be50dda..c79ec4731 100644 --- a/griptape/memory/meta/base_meta_entry.py +++ b/griptape/memory/meta/base_meta_entry.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define +from attrs import define from abc import ABC from griptape.mixins import SerializableMixin diff --git a/griptape/memory/meta/meta_memory.py b/griptape/memory/meta/meta_memory.py index 05d256455..214e6e285 100644 --- a/griptape/memory/meta/meta_memory.py +++ b/griptape/memory/meta/meta_memory.py @@ -1,4 +1,4 @@ -from attr import define, field +from attrs import define, field from griptape.memory.meta import BaseMetaEntry diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 85c90d3e1..6db05c92c 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -from attr import define, field +from attrs import define, field from griptape.memory.structure import Run from griptape.utils import PromptStack from griptape.mixins import SerializableMixin diff --git a/griptape/memory/structure/conversation_memory.py b/griptape/memory/structure/conversation_memory.py index 71e41c431..94e73d80c 100644 --- a/griptape/memory/structure/conversation_memory.py +++ b/griptape/memory/structure/conversation_memory.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define +from attrs import define from typing import Optional from griptape.memory.structure import Run, BaseConversationMemory from griptape.utils import PromptStack diff --git a/griptape/memory/structure/run.py b/griptape/memory/structure/run.py index f725bb9f4..c5a2b9b55 100644 --- a/griptape/memory/structure/run.py +++ b/griptape/memory/structure/run.py @@ -1,5 +1,5 @@ import uuid -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.mixins import SerializableMixin diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index fa06a3c76..e4d5597d5 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING, Optional -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.utils import J2, PromptStack from griptape.memory.structure import ConversationMemory diff --git a/griptape/memory/task/storage/base_artifact_storage.py b/griptape/memory/task/storage/base_artifact_storage.py index cb3d963be..fbd226363 100644 --- a/griptape/memory/task/storage/base_artifact_storage.py +++ b/griptape/memory/task/storage/base_artifact_storage.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Any from abc import ABC, abstractmethod -from attr import define +from attrs import define from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact, InfoArtifact diff --git a/griptape/memory/task/storage/blob_artifact_storage.py b/griptape/memory/task/storage/blob_artifact_storage.py index 322d06617..79b5798df 100644 --- a/griptape/memory/task/storage/blob_artifact_storage.py +++ b/griptape/memory/task/storage/blob_artifact_storage.py @@ -1,5 +1,5 @@ from typing import Any -from attr import define, field +from attrs import define, field from griptape.artifacts import BaseArtifact, ListArtifact, BlobArtifact, InfoArtifact from griptape.memory.task.storage import BaseArtifactStorage diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 8e4423f54..3b3162751 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Optional -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact from griptape.memory.task.storage import BaseArtifactStorage diff --git a/griptape/memory/task/task_memory.py b/griptape/memory/task/task_memory.py index 597b151c0..2f1fdbe16 100644 --- a/griptape/memory/task/task_memory.py +++ b/griptape/memory/task/task_memory.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional, Any, Callable -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, ErrorArtifact, TextArtifact from griptape.memory.meta import ActionSubtaskMetaEntry from griptape.mixins import ActivityMixin diff --git a/griptape/mixins/actions_subtask_origin_mixin.py b/griptape/mixins/actions_subtask_origin_mixin.py index 01711140b..22f06a1c6 100644 --- a/griptape/mixins/actions_subtask_origin_mixin.py +++ b/griptape/mixins/actions_subtask_origin_mixin.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from abc import abstractmethod -from attr import define +from attrs import define from schema import Schema, Literal if TYPE_CHECKING: diff --git a/griptape/mixins/activity_mixin.py b/griptape/mixins/activity_mixin.py index a28a8b412..73810864f 100644 --- a/griptape/mixins/activity_mixin.py +++ b/griptape/mixins/activity_mixin.py @@ -1,6 +1,6 @@ import inspect from typing import Optional, Callable -from attr import define, field +from attrs import define, field from jinja2 import Template from schema import Schema, Literal diff --git a/griptape/mixins/exponential_backoff_mixin.py b/griptape/mixins/exponential_backoff_mixin.py index 6559b7460..5045575f1 100644 --- a/griptape/mixins/exponential_backoff_mixin.py +++ b/griptape/mixins/exponential_backoff_mixin.py @@ -1,8 +1,8 @@ import logging from abc import ABC -from attr import define, field +from attrs import define, field from tenacity import Retrying, wait_exponential, stop_after_attempt, retry_if_not_exception_type -from typing import Tuple, Type, Callable +from typing import Callable @define(slots=False) @@ -11,7 +11,7 @@ class ExponentialBackoffMixin(ABC): max_retry_delay: float = field(default=10, kw_only=True) max_attempts: int = field(default=10, kw_only=True) after_hook: Callable = field(default=lambda s: logging.warning(s), kw_only=True) - ignored_exception_types: Tuple[Type[Exception], ...] = field(factory=tuple, kw_only=True) + ignored_exception_types: tuple[type[Exception], ...] = field(factory=tuple, kw_only=True) def retrying(self) -> Retrying: return Retrying( diff --git a/griptape/mixins/media_artifact_file_output_mixin.py b/griptape/mixins/media_artifact_file_output_mixin.py index d7d6f584c..14dcd4898 100644 --- a/griptape/mixins/media_artifact_file_output_mixin.py +++ b/griptape/mixins/media_artifact_file_output_mixin.py @@ -3,7 +3,7 @@ import os from typing import TYPE_CHECKING -from attr import define, field +from attrs import define, field from typing import Optional if TYPE_CHECKING: diff --git a/griptape/mixins/rule_mixin.py b/griptape/mixins/rule_mixin.py index 4c7c98225..a89309a15 100644 --- a/griptape/mixins/rule_mixin.py +++ b/griptape/mixins/rule_mixin.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Optional -from attr import define, field +from attrs import define, field from griptape.rules import Ruleset, Rule diff --git a/griptape/mixins/serializable_mixin.py b/griptape/mixins/serializable_mixin.py index ee80f5efa..c7a0bf035 100644 --- a/griptape/mixins/serializable_mixin.py +++ b/griptape/mixins/serializable_mixin.py @@ -3,7 +3,7 @@ import json from typing import TypeVar, Generic, cast, Optional -from attr import Factory, define, field +from attrs import Factory, define, field from abc import ABC from marshmallow import Schema diff --git a/griptape/rules/rule.py b/griptape/rules/rule.py index 5239396c8..f2a33c7e5 100644 --- a/griptape/rules/rule.py +++ b/griptape/rules/rule.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define +from attrs import define @define(frozen=True) diff --git a/griptape/rules/ruleset.py b/griptape/rules/ruleset.py index 71c0a78e7..9a78b58b8 100644 --- a/griptape/rules/ruleset.py +++ b/griptape/rules/ruleset.py @@ -1,4 +1,4 @@ -from attr import field, define +from attrs import field, define from griptape.rules import Rule diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index aee95dbf8..79b831f63 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -from attr import define, field +from attrs import define, field from griptape.tools import BaseTool from griptape.memory.structure import Run from griptape.structures import Structure diff --git a/griptape/structures/pipeline.py b/griptape/structures/pipeline.py index 8db5cfae9..00c5f1d09 100644 --- a/griptape/structures/pipeline.py +++ b/griptape/structures/pipeline.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional, Any -from attr import define +from attrs import define from griptape.artifacts import ErrorArtifact from griptape.memory.structure import Run from griptape.structures import Structure diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index 3d59b656d..e60efa425 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -2,7 +2,7 @@ import concurrent.futures as futures from graphlib import TopologicalSorter from typing import Any -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import ErrorArtifact from griptape.structures import Structure from griptape.tasks import BaseTask diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index ef2c2ce6f..d47d1df32 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -4,7 +4,7 @@ from typing import Optional, TYPE_CHECKING, Callable import schema -from attr import define, field +from attrs import define, field from griptape import utils from griptape.utils import remove_null_values_in_dict_recursively from griptape.mixins import ActionsSubtaskOriginMixin diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index 2e5572aba..d401af0a5 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -2,7 +2,7 @@ from abc import ABC -from attr import define +from attrs import define from griptape.mixins import RuleMixin, BlobArtifactFileOutputMixin from griptape.tasks import BaseTask diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index 2dbab4ce9..75f57b711 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -3,7 +3,7 @@ import os from abc import ABC -from attr import field, define +from attrs import field, define from griptape.artifacts import MediaArtifact from griptape.loaders import ImageLoader diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py index eb00af6ca..314cdb596 100644 --- a/griptape/tasks/base_multi_text_input_task.py +++ b/griptape/tasks/base_multi_text_input_task.py @@ -3,7 +3,7 @@ from abc import ABC from typing import Callable -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.mixins.rule_mixin import RuleMixin diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 037a3e526..79552ce66 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional from collections.abc import Sequence -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.events import StartTaskEvent, FinishTaskEvent from griptape.artifacts import ErrorArtifact diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index 3dd5d7009..c5641bb14 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -3,7 +3,7 @@ from abc import ABC from typing import Callable -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.mixins.rule_mixin import RuleMixin diff --git a/griptape/tasks/code_execution_task.py b/griptape/tasks/code_execution_task.py index a1255093d..038642b76 100644 --- a/griptape/tasks/code_execution_task.py +++ b/griptape/tasks/code_execution_task.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.artifacts import BaseArtifact, ErrorArtifact from griptape.tasks import BaseTextInputTask from typing import Callable diff --git a/griptape/tasks/csv_extraction_task.py b/griptape/tasks/csv_extraction_task.py index 2f5f3db56..8770187f0 100644 --- a/griptape/tasks/csv_extraction_task.py +++ b/griptape/tasks/csv_extraction_task.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.engines import CsvExtractionEngine from griptape.tasks import ExtractionTask diff --git a/griptape/tasks/extraction_task.py b/griptape/tasks/extraction_task.py index 76c0be395..03838ef6c 100644 --- a/griptape/tasks/extraction_task.py +++ b/griptape/tasks/extraction_task.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.artifacts import ListArtifact, ErrorArtifact from griptape.engines import BaseExtractionEngine from griptape.tasks import BaseTextInputTask diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py index 94be4f483..f6791f5a6 100644 --- a/griptape/tasks/image_query_task.py +++ b/griptape/tasks/image_query_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attr import define, field +from attrs import define, field from griptape.artifacts import ImageArtifact, TextArtifact from griptape.engines import ImageQueryEngine diff --git a/griptape/tasks/inpainting_image_generation_task.py b/griptape/tasks/inpainting_image_generation_task.py index f3b2edb7a..374820d01 100644 --- a/griptape/tasks/inpainting_image_generation_task.py +++ b/griptape/tasks/inpainting_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attr import define, field +from attrs import define, field from griptape.engines import InpaintingImageGenerationEngine from griptape.artifacts import ImageArtifact, TextArtifact diff --git a/griptape/tasks/outpainting_image_generation_task.py b/griptape/tasks/outpainting_image_generation_task.py index fd2d335e5..a35671de3 100644 --- a/griptape/tasks/outpainting_image_generation_task.py +++ b/griptape/tasks/outpainting_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attr import define, field +from attrs import define, field from griptape.engines import OutpaintingImageGenerationEngine from griptape.artifacts import ImageArtifact, TextArtifact diff --git a/griptape/tasks/prompt_image_generation_task.py b/griptape/tasks/prompt_image_generation_task.py index 93404ef84..577abacbb 100644 --- a/griptape/tasks/prompt_image_generation_task.py +++ b/griptape/tasks/prompt_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attr import define, field +from attrs import define, field from griptape.engines import PromptImageGenerationEngine from griptape.artifacts import ImageArtifact, TextArtifact diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 16f7c6dac..75051db74 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional, Callable -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.utils import PromptStack from griptape.utils import J2 from griptape.tasks import BaseTextInputTask diff --git a/griptape/tasks/structure_run_task.py b/griptape/tasks/structure_run_task.py index b7a8f9ea9..012f5c235 100644 --- a/griptape/tasks/structure_run_task.py +++ b/griptape/tasks/structure_run_task.py @@ -1,7 +1,7 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.artifacts import BaseArtifact from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver diff --git a/griptape/tasks/text_query_task.py b/griptape/tasks/text_query_task.py index 5fee68103..f35161c17 100644 --- a/griptape/tasks/text_query_task.py +++ b/griptape/tasks/text_query_task.py @@ -1,4 +1,4 @@ -from attr import define, field, Factory +from attrs import define, field, Factory from typing import Optional from griptape.artifacts import TextArtifact from griptape.engines import BaseQueryEngine, VectorQueryEngine diff --git a/griptape/tasks/text_summary_task.py b/griptape/tasks/text_summary_task.py index f10f851d0..648fb1cf1 100644 --- a/griptape/tasks/text_summary_task.py +++ b/griptape/tasks/text_summary_task.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.engines import PromptSummaryEngine from griptape.tasks import BaseTextInputTask diff --git a/griptape/tasks/text_to_speech_task.py b/griptape/tasks/text_to_speech_task.py index 8a69227c5..d69e907b1 100644 --- a/griptape/tasks/text_to_speech_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attr import define, field +from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact from griptape.engines import TextToSpeechEngine diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index f875186ea..edd90c26e 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -2,7 +2,7 @@ import re import json from typing import Optional, TYPE_CHECKING -from attr import define, field +from attrs import define, field from schema import Schema from griptape import utils diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index ead020c47..ed787aa45 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -1,7 +1,7 @@ from __future__ import annotations import json from typing import TYPE_CHECKING, Callable, Optional -from attr import define, field, Factory +from attrs import define, field, Factory from schema import Schema from griptape import utils diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index 1242bc59b..0d1269840 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attr import define, field +from attrs import define, field from griptape.engines import VariationImageGenerationEngine from griptape.artifacts import ImageArtifact, TextArtifact diff --git a/griptape/tokenizers/anthropic_tokenizer.py b/griptape/tokenizers/anthropic_tokenizer.py index 3aefbc980..577df7b93 100644 --- a/griptape/tokenizers/anthropic_tokenizer.py +++ b/griptape/tokenizers/anthropic_tokenizer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field, Factory +from attrs import define, field, Factory from typing import TYPE_CHECKING from griptape.utils import import_optional_dependency from griptape.tokenizers import BaseTokenizer diff --git a/griptape/tokenizers/base_tokenizer.py b/griptape/tokenizers/base_tokenizer.py index 28f30e66c..179d2fb59 100644 --- a/griptape/tokenizers/base_tokenizer.py +++ b/griptape/tokenizers/base_tokenizer.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from attr import define, field, Factory +from attrs import define, field, Factory from griptape import utils diff --git a/griptape/tokenizers/bedrock_claude_tokenizer.py b/griptape/tokenizers/bedrock_claude_tokenizer.py index d5ff9722b..d44116e2c 100644 --- a/griptape/tokenizers/bedrock_claude_tokenizer.py +++ b/griptape/tokenizers/bedrock_claude_tokenizer.py @@ -1,4 +1,4 @@ -from attr import define +from attrs import define from griptape.tokenizers import AnthropicTokenizer diff --git a/griptape/tokenizers/bedrock_cohere_tokenizer.py b/griptape/tokenizers/bedrock_cohere_tokenizer.py index 772453eb5..44ccb4ac6 100644 --- a/griptape/tokenizers/bedrock_cohere_tokenizer.py +++ b/griptape/tokenizers/bedrock_cohere_tokenizer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from .simple_tokenizer import SimpleTokenizer diff --git a/griptape/tokenizers/bedrock_jurassic_tokenizer.py b/griptape/tokenizers/bedrock_jurassic_tokenizer.py index 525e2e211..7511138b3 100644 --- a/griptape/tokenizers/bedrock_jurassic_tokenizer.py +++ b/griptape/tokenizers/bedrock_jurassic_tokenizer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field, Factory +from attrs import define, field, Factory from .simple_tokenizer import SimpleTokenizer diff --git a/griptape/tokenizers/bedrock_llama_tokenizer.py b/griptape/tokenizers/bedrock_llama_tokenizer.py index 051b4552d..e7d1ec829 100644 --- a/griptape/tokenizers/bedrock_llama_tokenizer.py +++ b/griptape/tokenizers/bedrock_llama_tokenizer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from .simple_tokenizer import SimpleTokenizer diff --git a/griptape/tokenizers/bedrock_titan_tokenizer.py b/griptape/tokenizers/bedrock_titan_tokenizer.py index 403fa2039..0d8ba0273 100644 --- a/griptape/tokenizers/bedrock_titan_tokenizer.py +++ b/griptape/tokenizers/bedrock_titan_tokenizer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field, Factory +from attrs import define, field, Factory from .simple_tokenizer import SimpleTokenizer diff --git a/griptape/tokenizers/cohere_tokenizer.py b/griptape/tokenizers/cohere_tokenizer.py index 4856b7ff7..0a3c6a236 100644 --- a/griptape/tokenizers/cohere_tokenizer.py +++ b/griptape/tokenizers/cohere_tokenizer.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING -from attr import define, field +from attrs import define, field from griptape.tokenizers import BaseTokenizer if TYPE_CHECKING: diff --git a/griptape/tokenizers/google_tokenizer.py b/griptape/tokenizers/google_tokenizer.py index 26d1f16b0..55942f597 100644 --- a/griptape/tokenizers/google_tokenizer.py +++ b/griptape/tokenizers/google_tokenizer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field, Factory +from attrs import define, field, Factory from typing import TYPE_CHECKING from griptape.utils import import_optional_dependency from griptape.tokenizers import BaseTokenizer diff --git a/griptape/tokenizers/huggingface_tokenizer.py b/griptape/tokenizers/huggingface_tokenizer.py index c3203d4a9..dbfba5429 100644 --- a/griptape/tokenizers/huggingface_tokenizer.py +++ b/griptape/tokenizers/huggingface_tokenizer.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.tokenizers import BaseTokenizer if TYPE_CHECKING: diff --git a/griptape/tokenizers/openai_tokenizer.py b/griptape/tokenizers/openai_tokenizer.py index dda8bfe15..ec127ca1a 100644 --- a/griptape/tokenizers/openai_tokenizer.py +++ b/griptape/tokenizers/openai_tokenizer.py @@ -1,6 +1,6 @@ from __future__ import annotations import logging -from attr import define +from attrs import define import tiktoken from griptape.tokenizers import BaseTokenizer from typing import Optional diff --git a/griptape/tokenizers/simple_tokenizer.py b/griptape/tokenizers/simple_tokenizer.py index d97c3bd89..484afe69f 100644 --- a/griptape/tokenizers/simple_tokenizer.py +++ b/griptape/tokenizers/simple_tokenizer.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional -from attr import define, field +from attrs import define, field from griptape.tokenizers import BaseTokenizer diff --git a/griptape/tokenizers/voyageai_tokenizer.py b/griptape/tokenizers/voyageai_tokenizer.py index 799c64276..565e53faa 100644 --- a/griptape/tokenizers/voyageai_tokenizer.py +++ b/griptape/tokenizers/voyageai_tokenizer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field, Factory +from attrs import define, field, Factory from typing import TYPE_CHECKING, Optional from griptape.utils import import_optional_dependency from griptape.tokenizers import BaseTokenizer diff --git a/griptape/tools/aws_iam_client/tool.py b/griptape/tools/aws_iam_client/tool.py index 0c94be90e..467540829 100644 --- a/griptape/tools/aws_iam_client/tool.py +++ b/griptape/tools/aws_iam_client/tool.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from schema import Schema, Literal -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact, ErrorArtifact, ListArtifact from griptape.utils.decorators import activity from griptape.tools import BaseAwsClient diff --git a/griptape/tools/aws_s3_client/tool.py b/griptape/tools/aws_s3_client/tool.py index fc63f3988..9d50c7cf8 100644 --- a/griptape/tools/aws_s3_client/tool.py +++ b/griptape/tools/aws_s3_client/tool.py @@ -2,7 +2,7 @@ import io from typing import Any, TYPE_CHECKING from schema import Schema, Literal -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import TextArtifact, ErrorArtifact, InfoArtifact, ListArtifact, BlobArtifact from griptape.utils.decorators import activity from griptape.tools import BaseAwsClient diff --git a/griptape/tools/base_aws_client.py b/griptape/tools/base_aws_client.py index cb6159a34..bd7b41e59 100644 --- a/griptape/tools/base_aws_client.py +++ b/griptape/tools/base_aws_client.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING from abc import ABC -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact, ErrorArtifact, BaseArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity diff --git a/griptape/tools/base_google_client.py b/griptape/tools/base_google_client.py index b0da963b5..51ae2f946 100644 --- a/griptape/tools/base_google_client.py +++ b/griptape/tools/base_google_client.py @@ -1,5 +1,5 @@ from abc import ABC -from attr import define, field +from attrs import define, field from griptape.tools import BaseTool from typing import Optional, Any diff --git a/griptape/tools/base_griptape_cloud_client.py b/griptape/tools/base_griptape_cloud_client.py index cafe01cd1..8c42f817e 100644 --- a/griptape/tools/base_griptape_cloud_client.py +++ b/griptape/tools/base_griptape_cloud_client.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC -from attr import Factory, define, field +from attrs import Factory, define, field from griptape.tools import BaseTool diff --git a/griptape/tools/base_tool.py b/griptape/tools/base_tool.py index 2c9f74839..4f1255665 100644 --- a/griptape/tools/base_tool.py +++ b/griptape/tools/base_tool.py @@ -9,7 +9,7 @@ from abc import ABC from typing import Optional import yaml -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import BaseArtifact, InfoArtifact, TextArtifact from griptape.mixins import ActivityMixin diff --git a/griptape/tools/computer/tool.py b/griptape/tools/computer/tool.py index 57953f005..b63ba9e0b 100644 --- a/griptape/tools/computer/tool.py +++ b/griptape/tools/computer/tool.py @@ -5,7 +5,7 @@ import tempfile from pathlib import Path from typing import Optional, TYPE_CHECKING -from attr import define, field, Factory +from attrs import define, field, Factory from docker.models.containers import Container from schema import Schema, Literal import stringcase diff --git a/griptape/tools/email_client/tool.py b/griptape/tools/email_client/tool.py index b67a16e4c..e2b42ea7e 100644 --- a/griptape/tools/email_client/tool.py +++ b/griptape/tools/email_client/tool.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import Factory, define, field +from attrs import Factory, define, field from email.mime.text import MIMEText from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact from griptape.loaders.email_loader import EmailLoader diff --git a/griptape/tools/file_manager/tool.py b/griptape/tools/file_manager/tool.py index 6aa2c85c0..162b546a6 100644 --- a/griptape/tools/file_manager/tool.py +++ b/griptape/tools/file_manager/tool.py @@ -1,6 +1,6 @@ from __future__ import annotations import os -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.drivers import BaseFileManagerDriver, LocalFileManagerDriver from griptape.tools import BaseTool diff --git a/griptape/tools/google_cal/tool.py b/griptape/tools/google_cal/tool.py index 7ca83e3a5..b18c5f41c 100644 --- a/griptape/tools/google_cal/tool.py +++ b/griptape/tools/google_cal/tool.py @@ -2,7 +2,7 @@ import logging import datetime from schema import Schema, Literal, Optional -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact, ErrorArtifact, InfoArtifact, ListArtifact from griptape.utils.decorators import activity from griptape.tools import BaseGoogleClient diff --git a/griptape/tools/google_docs/tool.py b/griptape/tools/google_docs/tool.py index 2aaffa13e..48dbbe39c 100644 --- a/griptape/tools/google_docs/tool.py +++ b/griptape/tools/google_docs/tool.py @@ -1,6 +1,6 @@ from __future__ import annotations import logging -from attr import field, define +from attrs import field, define from schema import Schema, Optional, Literal from griptape.artifacts import ErrorArtifact, InfoArtifact diff --git a/griptape/tools/google_drive/tool.py b/griptape/tools/google_drive/tool.py index 761b38754..fcb042cfb 100644 --- a/griptape/tools/google_drive/tool.py +++ b/griptape/tools/google_drive/tool.py @@ -3,7 +3,7 @@ from typing import Any, Optional import schema from schema import Schema, Literal, Or -from attr import define, field +from attrs import define, field from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact, BlobArtifact, TextArtifact from griptape.utils.decorators import activity from griptape.tools import BaseGoogleClient diff --git a/griptape/tools/google_gmail/tool.py b/griptape/tools/google_gmail/tool.py index 7539b6668..5b4bf5cd5 100644 --- a/griptape/tools/google_gmail/tool.py +++ b/griptape/tools/google_gmail/tool.py @@ -3,7 +3,7 @@ import base64 from email.message import EmailMessage from schema import Schema, Literal -from attr import define, field +from attrs import define, field from griptape.artifacts import InfoArtifact, ErrorArtifact from griptape.utils.decorators import activity from griptape.tools import BaseGoogleClient diff --git a/griptape/tools/griptape_cloud_knowledge_base_client/tool.py b/griptape/tools/griptape_cloud_knowledge_base_client/tool.py index 89a284c98..917406cdd 100644 --- a/griptape/tools/griptape_cloud_knowledge_base_client/tool.py +++ b/griptape/tools/griptape_cloud_knowledge_base_client/tool.py @@ -2,7 +2,7 @@ from typing import Optional from urllib.parse import urljoin from schema import Schema, Literal -from attr import define, field +from attrs import define, field from griptape.tools.base_griptape_cloud_client import BaseGriptapeCloudClient from griptape.utils.decorators import activity from griptape.artifacts import TextArtifact, ErrorArtifact diff --git a/griptape/tools/image_query_client/tool.py b/griptape/tools/image_query_client/tool.py index eab3b5e3e..60f2970ff 100644 --- a/griptape/tools/image_query_client/tool.py +++ b/griptape/tools/image_query_client/tool.py @@ -2,7 +2,7 @@ from typing import Any, cast -from attr import define, field, Factory +from attrs import define, field, Factory from schema import Schema, Literal from griptape.artifacts import TextArtifact, ImageArtifact, ErrorArtifact, BlobArtifact diff --git a/griptape/tools/openweather_client/tool.py b/griptape/tools/openweather_client/tool.py index 875fc2cf4..e338677dd 100644 --- a/griptape/tools/openweather_client/tool.py +++ b/griptape/tools/openweather_client/tool.py @@ -4,7 +4,7 @@ from griptape.utils.decorators import activity from schema import Schema, Literal from typing import Optional -from attr import define, field +from attrs import define, field import requests import logging diff --git a/griptape/tools/rest_api_client/tool.py b/griptape/tools/rest_api_client/tool.py index 97bd9384e..e33971cc2 100644 --- a/griptape/tools/rest_api_client/tool.py +++ b/griptape/tools/rest_api_client/tool.py @@ -1,9 +1,9 @@ from textwrap import dedent -from typing import Optional, Dict +from typing import Optional from urllib.parse import urljoin import schema from schema import Schema, Literal -from attr import define, field +from attrs import define, field from griptape.tools import BaseTool from griptape.utils.decorators import activity from griptape.artifacts import BaseArtifact, TextArtifact, ErrorArtifact @@ -30,7 +30,7 @@ class RestApiClient(BaseTool): request_query_params_schema: Optional[str] = field(default=None, kw_only=True) request_body_schema: Optional[str] = field(default=None, kw_only=True) response_body_schema: Optional[str] = field(default=None, kw_only=True) - request_headers: Optional[Dict[str, str]] = field(default=None, kw_only=True) + request_headers: Optional[dict[str, str]] = field(default=None, kw_only=True) @property def full_url(self) -> str: diff --git a/griptape/tools/sql_client/tool.py b/griptape/tools/sql_client/tool.py index e5ed9263e..a07c2e734 100644 --- a/griptape/tools/sql_client/tool.py +++ b/griptape/tools/sql_client/tool.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional -from attr import define, field +from attrs import define, field from griptape.artifacts import InfoArtifact, ListArtifact, ErrorArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity diff --git a/griptape/tools/structure_run_client/tool.py b/griptape/tools/structure_run_client/tool.py index c62b53e97..f48e84a1d 100644 --- a/griptape/tools/structure_run_client/tool.py +++ b/griptape/tools/structure_run_client/tool.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from schema import Literal, Schema from griptape.artifacts import BaseArtifact, TextArtifact diff --git a/griptape/tools/task_memory_client/tool.py b/griptape/tools/task_memory_client/tool.py index b7234aca2..2a85a7c0f 100644 --- a/griptape/tools/task_memory_client/tool.py +++ b/griptape/tools/task_memory_client/tool.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from schema import Schema, Literal from griptape.artifacts import TextArtifact, ErrorArtifact, InfoArtifact from griptape.tools import BaseTool diff --git a/griptape/tools/vector_store_client/tool.py b/griptape/tools/vector_store_client/tool.py index 90ba845c0..f2dc785b0 100644 --- a/griptape/tools/vector_store_client/tool.py +++ b/griptape/tools/vector_store_client/tool.py @@ -1,7 +1,7 @@ from typing import Optional from griptape.engines import VectorQueryEngine from schema import Schema, Literal -from attr import define, field +from attrs import define, field from griptape.artifacts import BaseArtifact, ErrorArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity diff --git a/griptape/tools/web_scraper/tool.py b/griptape/tools/web_scraper/tool.py index bff46473f..c42dedad9 100644 --- a/griptape/tools/web_scraper/tool.py +++ b/griptape/tools/web_scraper/tool.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.artifacts import ErrorArtifact, ListArtifact from schema import Schema, Literal from griptape.tools import BaseTool diff --git a/griptape/tools/web_search/tool.py b/griptape/tools/web_search/tool.py index af8a75d11..6ad25ebf0 100644 --- a/griptape/tools/web_search/tool.py +++ b/griptape/tools/web_search/tool.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact, ErrorArtifact, ListArtifact from schema import Schema, Literal from griptape.tools import BaseTool diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 549d93e53..12653ce9e 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional, Callable -from attr import Factory, define, field +from attrs import Factory, define, field from griptape.utils.stream import Stream if TYPE_CHECKING: diff --git a/griptape/utils/command_runner.py b/griptape/utils/command_runner.py index c1ad9c9fe..bbc03ec39 100644 --- a/griptape/utils/command_runner.py +++ b/griptape/utils/command_runner.py @@ -1,5 +1,5 @@ import subprocess -from attr import define +from attrs import define from griptape.artifacts import BaseArtifact, TextArtifact, ErrorArtifact diff --git a/griptape/utils/conversation.py b/griptape/utils/conversation.py index 38b3e30a9..2d87563ae 100644 --- a/griptape/utils/conversation.py +++ b/griptape/utils/conversation.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING -from attr import define, field +from attrs import define, field if TYPE_CHECKING: from griptape.memory.structure import ConversationMemory diff --git a/griptape/utils/j2.py b/griptape/utils/j2.py index 6b57e274c..ca54fed9e 100644 --- a/griptape/utils/j2.py +++ b/griptape/utils/j2.py @@ -1,5 +1,5 @@ from typing import Optional -from attr import define, field, Factory +from attrs import define, field, Factory from jinja2 import Environment, FileSystemLoader from .paths import abs_path diff --git a/griptape/utils/prompt_stack.py b/griptape/utils/prompt_stack.py index a5f336030..f04cef486 100644 --- a/griptape/utils/prompt_stack.py +++ b/griptape/utils/prompt_stack.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -from attr import define, field +from attrs import define, field from griptape.mixins import SerializableMixin diff --git a/griptape/utils/python_runner.py b/griptape/utils/python_runner.py index 0c54ad652..75b6e53c2 100644 --- a/griptape/utils/python_runner.py +++ b/griptape/utils/python_runner.py @@ -1,7 +1,7 @@ import importlib import sys from io import StringIO -from attr import define, field +from attrs import define, field @define diff --git a/griptape/utils/token_counter.py b/griptape/utils/token_counter.py index eccae3afd..2732d95f1 100644 --- a/griptape/utils/token_counter.py +++ b/griptape/utils/token_counter.py @@ -1,4 +1,4 @@ -from attr import define, field +from attrs import define, field @define diff --git a/poetry.lock b/poetry.lock index c4e2e1024..c78b72391 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4236,7 +4236,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4512,28 +4511,28 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.4.4" +version = "0.4.6" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.4.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:29d44ef5bb6a08e235c8249294fa8d431adc1426bfda99ed493119e6f9ea1bf6"}, - {file = "ruff-0.4.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c4efe62b5bbb24178c950732ddd40712b878a9b96b1d02b0ff0b08a090cbd891"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c8e2f1e8fc12d07ab521a9005d68a969e167b589cbcaee354cb61e9d9de9c15"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:60ed88b636a463214905c002fa3eaab19795679ed55529f91e488db3fe8976ab"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b90fc5e170fc71c712cc4d9ab0e24ea505c6a9e4ebf346787a67e691dfb72e85"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:8e7e6ebc10ef16dcdc77fd5557ee60647512b400e4a60bdc4849468f076f6eef"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b9ddb2c494fb79fc208cd15ffe08f32b7682519e067413dbaf5f4b01a6087bcd"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c51c928a14f9f0a871082603e25a1588059b7e08a920f2f9fa7157b5bf08cfe9"}, - {file = "ruff-0.4.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5eb0a4bfd6400b7d07c09a7725e1a98c3b838be557fee229ac0f84d9aa49c36"}, - {file = "ruff-0.4.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b1867ee9bf3acc21778dcb293db504692eda5f7a11a6e6cc40890182a9f9e595"}, - {file = "ruff-0.4.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1aecced1269481ef2894cc495647392a34b0bf3e28ff53ed95a385b13aa45768"}, - {file = "ruff-0.4.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9da73eb616b3241a307b837f32756dc20a0b07e2bcb694fec73699c93d04a69e"}, - {file = "ruff-0.4.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:958b4ea5589706a81065e2a776237de2ecc3e763342e5cc8e02a4a4d8a5e6f95"}, - {file = "ruff-0.4.4-py3-none-win32.whl", hash = "sha256:cb53473849f011bca6e754f2cdf47cafc9c4f4ff4570003a0dad0b9b6890e876"}, - {file = "ruff-0.4.4-py3-none-win_amd64.whl", hash = "sha256:424e5b72597482543b684c11def82669cc6b395aa8cc69acc1858b5ef3e5daae"}, - {file = "ruff-0.4.4-py3-none-win_arm64.whl", hash = "sha256:39df0537b47d3b597293edbb95baf54ff5b49589eb7ff41926d8243caa995ea6"}, - {file = "ruff-0.4.4.tar.gz", hash = "sha256:f87ea42d5cdebdc6a69761a9d0bc83ae9b3b30d0ad78952005ba6568d6c022af"}, + {file = "ruff-0.4.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ef995583a038cd4a7edf1422c9e19118e2511b8ba0b015861b4abd26ec5367c5"}, + {file = "ruff-0.4.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:602ebd7ad909eab6e7da65d3c091547781bb06f5f826974a53dbe563d357e53c"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f9ced5cbb7510fd7525448eeb204e0a22cabb6e99a3cb160272262817d49786"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04a80acfc862e0e1630c8b738e70dcca03f350bad9e106968a8108379e12b31f"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be47700ecb004dfa3fd4dcdddf7322d4e632de3c06cd05329d69c45c0280e618"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1ff930d6e05f444090a0139e4e13e1e2e1f02bd51bb4547734823c760c621e79"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f13410aabd3b5776f9c5699f42b37a3a348d65498c4310589bc6e5c548dc8a2f"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0cf5cc02d3ae52dfb0c8a946eb7a1d6ffe4d91846ffc8ce388baa8f627e3bd50"}, + {file = "ruff-0.4.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea3424793c29906407e3cf417f28fc33f689dacbbadfb52b7e9a809dd535dcef"}, + {file = "ruff-0.4.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1fa8561489fadf483ffbb091ea94b9c39a00ed63efacd426aae2f197a45e67fc"}, + {file = "ruff-0.4.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4d5b914818d8047270308fe3e85d9d7f4a31ec86c6475c9f418fbd1624d198e0"}, + {file = "ruff-0.4.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:4f02284335c766678778475e7698b7ab83abaf2f9ff0554a07b6f28df3b5c259"}, + {file = "ruff-0.4.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3a6a0a4f4b5f54fff7c860010ab3dd81425445e37d35701a965c0248819dde7a"}, + {file = "ruff-0.4.6-py3-none-win32.whl", hash = "sha256:9018bf59b3aa8ad4fba2b1dc0299a6e4e60a4c3bc62bbeaea222679865453062"}, + {file = "ruff-0.4.6-py3-none-win_amd64.whl", hash = "sha256:a769ae07ac74ff1a019d6bd529426427c3e30d75bdf1e08bb3d46ac8f417326a"}, + {file = "ruff-0.4.6-py3-none-win_arm64.whl", hash = "sha256:735a16407a1a8f58e4c5b913ad6102722e80b562dd17acb88887685ff6f20cf6"}, + {file = "ruff-0.4.6.tar.gz", hash = "sha256:a797a87da50603f71e6d0765282098245aca6e3b94b7c17473115167d8dfb0b7"}, ] [[package]] @@ -6022,4 +6021,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "2590fc0528b584775d753939abbde032c34055a92db2538385002b9808d1fa31" +content-hash = "ce7c88b4d4ea368bd7a6c08c8e0f4310b2c10f1237d77d80f50bda4b35612481" diff --git a/pyproject.toml b/pyproject.toml index 5ca064399..af264a5dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,7 +155,7 @@ pytest-clarity = "^1.0.1" optional = true [tool.poetry.group.dev.dependencies] -ruff = "^0.4.4" +ruff = "^0.4.6" pyright = "^1.1.363" pre-commit = "^3.7.1" boto3-stubs = {extras = ["bedrock", "iam", "opensearch", "s3", "sagemaker"], version = "^1.34.105"} @@ -180,6 +180,12 @@ line-length = 120 [tool.ruff.format] skip-magic-trailing-comma = true +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "TID251"] + +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"attr".msg = "The attr module is deprecated, use attrs instead." + [tool.pyright] venvPath = "." venv = ".venv" diff --git a/tests/mocks/invalid_mock_tool/tool.py b/tests/mocks/invalid_mock_tool/tool.py index 3700d28eb..91b2f78f7 100644 --- a/tests/mocks/invalid_mock_tool/tool.py +++ b/tests/mocks/invalid_mock_tool/tool.py @@ -1,4 +1,4 @@ -from attr import define, field +from attrs import define, field from schema import Schema, Literal from griptape.tools import BaseTool from griptape.utils.decorators import activity diff --git a/tests/mocks/mock_embedding_driver.py b/tests/mocks/mock_embedding_driver.py index 9502421a9..e21c56308 100644 --- a/tests/mocks/mock_embedding_driver.py +++ b/tests/mocks/mock_embedding_driver.py @@ -1,4 +1,4 @@ -from attr import field, define +from attrs import field, define from griptape.drivers import BaseEmbeddingDriver from tests.mocks.mock_tokenizer import MockTokenizer diff --git a/tests/mocks/mock_event_listener_driver.py b/tests/mocks/mock_event_listener_driver.py index dd54eeb73..560fb8733 100644 --- a/tests/mocks/mock_event_listener_driver.py +++ b/tests/mocks/mock_event_listener_driver.py @@ -1,4 +1,4 @@ -from attr import define +from attrs import define from griptape.drivers import BaseEventListenerDriver diff --git a/tests/mocks/mock_failing_prompt_driver.py b/tests/mocks/mock_failing_prompt_driver.py index 3bc0e1511..c97b25d86 100644 --- a/tests/mocks/mock_failing_prompt_driver.py +++ b/tests/mocks/mock_failing_prompt_driver.py @@ -1,5 +1,5 @@ -from typing import Iterator -from attr import define +from collections.abc import Iterator +from attrs import define from griptape.utils import PromptStack from griptape.drivers import BasePromptDriver diff --git a/tests/mocks/mock_image_generation_driver.py b/tests/mocks/mock_image_generation_driver.py index 0775ad5e2..de94771e2 100644 --- a/tests/mocks/mock_image_generation_driver.py +++ b/tests/mocks/mock_image_generation_driver.py @@ -1,5 +1,5 @@ from typing import Optional -from attr import define +from attrs import define from griptape.artifacts import ImageArtifact from griptape.drivers.image_generation.base_image_generation_driver import BaseImageGenerationDriver diff --git a/tests/mocks/mock_image_generation_task.py b/tests/mocks/mock_image_generation_task.py index 24c25dc37..1c79b42a9 100644 --- a/tests/mocks/mock_image_generation_task.py +++ b/tests/mocks/mock_image_generation_task.py @@ -1,4 +1,4 @@ -from attr import define, field +from attrs import define, field from griptape.artifacts import ImageArtifact, TextArtifact from griptape.tasks import BaseImageGenerationTask diff --git a/tests/mocks/mock_image_query_driver.py b/tests/mocks/mock_image_query_driver.py index b25684178..d3bec164f 100644 --- a/tests/mocks/mock_image_query_driver.py +++ b/tests/mocks/mock_image_query_driver.py @@ -1,5 +1,5 @@ from typing import Optional -from attr import define +from attrs import define from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers import BaseImageQueryDriver from griptape.drivers.image_generation.base_image_generation_driver import BaseImageGenerationDriver diff --git a/tests/mocks/mock_multi_text_input_task.py b/tests/mocks/mock_multi_text_input_task.py index 1da645ee4..7ab5aedf9 100644 --- a/tests/mocks/mock_multi_text_input_task.py +++ b/tests/mocks/mock_multi_text_input_task.py @@ -1,4 +1,4 @@ -from attr import define +from attrs import define from griptape.artifacts import TextArtifact from griptape.tasks import BaseMultiTextInputTask diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index e2018c6f6..20913a965 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -1,6 +1,7 @@ +from __future__ import annotations from collections.abc import Iterator from typing import Callable -from attr import Factory, define, field +from attrs import define, field from griptape.utils import PromptStack from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer diff --git a/tests/mocks/mock_task.py b/tests/mocks/mock_task.py index a3ea7688d..42595f6eb 100644 --- a/tests/mocks/mock_task.py +++ b/tests/mocks/mock_task.py @@ -1,4 +1,4 @@ -from attr import define, field +from attrs import define, field from griptape.artifacts import TextArtifact, BaseArtifact from griptape.tasks import BaseTask diff --git a/tests/mocks/mock_text_input_task.py b/tests/mocks/mock_text_input_task.py index 5c8c8174c..930c77e74 100644 --- a/tests/mocks/mock_text_input_task.py +++ b/tests/mocks/mock_text_input_task.py @@ -1,4 +1,4 @@ -from attr import define +from attrs import define from griptape.artifacts import TextArtifact from griptape.tasks import BaseTextInputTask diff --git a/tests/mocks/mock_tokenizer.py b/tests/mocks/mock_tokenizer.py index 56a5bc5cc..a333f9a13 100644 --- a/tests/mocks/mock_tokenizer.py +++ b/tests/mocks/mock_tokenizer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from attr import define, field +from attrs import define, field from griptape.tokenizers import BaseTokenizer diff --git a/tests/mocks/mock_tool/tool.py b/tests/mocks/mock_tool/tool.py index ea66b4aa1..266f77c1b 100644 --- a/tests/mocks/mock_tool/tool.py +++ b/tests/mocks/mock_tool/tool.py @@ -1,4 +1,4 @@ -from attr import define, field +from attrs import define, field from schema import Schema, Literal from griptape.artifacts import TextArtifact, ErrorArtifact, BaseArtifact, ListArtifact from griptape.tools import BaseTool diff --git a/tests/mocks/mock_value_prompt_driver.py b/tests/mocks/mock_value_prompt_driver.py index 8c660602b..12ddeec9f 100644 --- a/tests/mocks/mock_value_prompt_driver.py +++ b/tests/mocks/mock_value_prompt_driver.py @@ -1,5 +1,5 @@ from collections.abc import Iterator -from attr import define, field, Factory +from attrs import define, field, Factory from griptape.drivers import BasePromptDriver from griptape.tokenizers import OpenAiTokenizer, BaseTokenizer from griptape.artifacts import TextArtifact diff --git a/tests/unit/artifacts/test_base_media_artifact.py b/tests/unit/artifacts/test_base_media_artifact.py index d72a8bb02..2829a1e2f 100644 --- a/tests/unit/artifacts/test_base_media_artifact.py +++ b/tests/unit/artifacts/test_base_media_artifact.py @@ -1,6 +1,6 @@ import pytest -from attr import define +from attrs import define from griptape.artifacts import MediaArtifact diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index abd8f0e0a..4f111b8d8 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -1,6 +1,6 @@ from __future__ import annotations import os -from attr import field, define +from attrs import field, define from schema import Schema, Literal import logging import json From 5a0e1c3e0a99d4fab8ad27263c569b644a3b571a Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 31 May 2024 11:11:24 -0700 Subject: [PATCH 3/4] Refactor tasks to use ListArtifact as input (#811) --- griptape/artifacts/list_artifact.py | 3 +- griptape/events/base_task_event.py | 7 +-- griptape/events/finish_structure_run_event.py | 7 +-- griptape/events/start_structure_run_event.py | 7 +-- griptape/tasks/base_multi_text_input_task.py | 17 +++--- griptape/tasks/base_task.py | 3 +- griptape/tasks/image_query_task.py | 31 +++++++---- .../tasks/inpainting_image_generation_task.py | 23 +++++--- .../outpainting_image_generation_task.py | 23 +++++--- .../tasks/variation_image_generation_task.py | 17 +++--- tests/unit/artifacts/test_list_artifact.py | 2 +- .../events/test_finish_structure_run_event.py | 14 +++-- .../tasks/test_base_multi_text_input_task.py | 1 - tests/unit/tasks/test_image_query_task.py | 52 +++++++++++++++---- .../test_inpainting_image_generation_task.py | 25 ++++++--- .../test_outpainting_image_generation_task.py | 25 ++++++--- .../test_variation_image_generation_task.py | 22 +++++--- 17 files changed, 188 insertions(+), 91 deletions(-) diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 558f32432..68b377df2 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -8,10 +8,11 @@ class ListArtifact(BaseArtifact): value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True}) item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True}) + validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True}) @value.validator # pyright: ignore def validate_value(self, _, value: list[BaseArtifact]) -> None: - if len(value) > 0: + if self.validate_uniform_types and len(value) > 0: first_type = type(value[0]) if not all(isinstance(v, first_type) for v in value): diff --git a/griptape/events/base_task_event.py b/griptape/events/base_task_event.py index c037ec7ab..e853114d5 100644 --- a/griptape/events/base_task_event.py +++ b/griptape/events/base_task_event.py @@ -1,8 +1,7 @@ from __future__ import annotations from attrs import define, field from abc import ABC -from typing import Optional, Union -from collections.abc import Sequence +from typing import Optional from griptape.artifacts import BaseArtifact from .base_event import BaseEvent @@ -13,7 +12,5 @@ class BaseTaskEvent(BaseEvent, ABC): task_parent_ids: list[str] = field(kw_only=True, metadata={"serializable": True}) task_child_ids: list[str] = field(kw_only=True, metadata={"serializable": True}) - task_input: Union[BaseArtifact, tuple[BaseArtifact, ...], tuple[BaseArtifact, Sequence[BaseArtifact]]] = field( - kw_only=True, metadata={"serializable": True} - ) + task_input: BaseArtifact = field(kw_only=True, metadata={"serializable": True}) task_output: Optional[BaseArtifact] = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/finish_structure_run_event.py b/griptape/events/finish_structure_run_event.py index 2ff7786d8..588a5be31 100644 --- a/griptape/events/finish_structure_run_event.py +++ b/griptape/events/finish_structure_run_event.py @@ -1,5 +1,4 @@ -from typing import Optional, Union -from collections.abc import Sequence +from typing import Optional from attrs import define, field @@ -10,7 +9,5 @@ @define class FinishStructureRunEvent(BaseEvent): structure_id: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) - output_task_input: Union[BaseArtifact, tuple[BaseArtifact, ...], tuple[BaseArtifact, Sequence[BaseArtifact]]] = ( - field(kw_only=True, metadata={"serializable": True}) - ) + output_task_input: BaseArtifact = field(kw_only=True, metadata={"serializable": True}) output_task_output: Optional[BaseArtifact] = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/start_structure_run_event.py b/griptape/events/start_structure_run_event.py index e75c5cd68..f0bb5528c 100644 --- a/griptape/events/start_structure_run_event.py +++ b/griptape/events/start_structure_run_event.py @@ -1,5 +1,4 @@ -from typing import Optional, Union -from collections.abc import Sequence +from typing import Optional from attrs import define, field @@ -10,7 +9,5 @@ @define class StartStructureRunEvent(BaseEvent): structure_id: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) - input_task_input: Union[BaseArtifact, tuple[BaseArtifact, ...], tuple[BaseArtifact, Sequence[BaseArtifact]]] = ( - field(kw_only=True, metadata={"serializable": True}) - ) + input_task_input: BaseArtifact = field(kw_only=True, metadata={"serializable": True}) input_task_output: Optional[BaseArtifact] = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py index 314cdb596..385bc9b5b 100644 --- a/griptape/tasks/base_multi_text_input_task.py +++ b/griptape/tasks/base_multi_text_input_task.py @@ -5,7 +5,7 @@ from attrs import define, field, Factory -from griptape.artifacts import TextArtifact +from griptape.artifacts import ListArtifact, TextArtifact from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 @@ -20,20 +20,19 @@ class BaseMultiTextInputTask(RuleMixin, BaseTask, ABC): ) @property - def input(self) -> tuple[TextArtifact, ...]: + def input(self) -> ListArtifact: if all(isinstance(elem, TextArtifact) for elem in self._input): - return self._input # pyright: ignore + return ListArtifact([artifact for artifact in self._input if isinstance(artifact, TextArtifact)]) elif all(isinstance(elem, Callable) for elem in self._input): - return tuple([elem(self) for elem in self._input]) # pyright: ignore - elif isinstance(self._input, tuple): - return tuple( + return ListArtifact([callable(self) for callable in self._input if isinstance(callable, Callable)]) + else: + return ListArtifact( [ - TextArtifact(J2().render_from_string(input_template, **self.full_context)) # pyright: ignore + TextArtifact(J2().render_from_string(input_template, **self.full_context)) for input_template in self._input + if isinstance(input_template, str) ] ) - else: - return tuple([TextArtifact(J2().render_from_string(self._input, **self.full_context))]) @input.setter def input( diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 79552ce66..771fe4dc8 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -5,7 +5,6 @@ from concurrent import futures from enum import Enum from typing import TYPE_CHECKING, Any, Optional -from collections.abc import Sequence from attrs import define, field, Factory @@ -38,7 +37,7 @@ class State(Enum): @property @abstractmethod - def input(self) -> BaseArtifact | tuple[BaseArtifact, ...] | tuple[BaseArtifact, Sequence[BaseArtifact]]: ... + def input(self) -> BaseArtifact: ... @property def parents(self) -> list[BaseTask]: diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py index f6791f5a6..85b25715a 100644 --- a/griptape/tasks/image_query_task.py +++ b/griptape/tasks/image_query_task.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import ImageQueryEngine from griptape.tasks import BaseTask from griptape.utils import J2 @@ -26,18 +26,21 @@ class ImageQueryTask(BaseTask): _input: ( tuple[str, list[ImageArtifact]] | tuple[TextArtifact, list[ImageArtifact]] - | Callable[[BaseTask], tuple[TextArtifact, list[ImageArtifact]]] + | Callable[[BaseTask], ListArtifact] + | ListArtifact ) = field(default=None, alias="input") @property - def input(self) -> tuple[TextArtifact, list[ImageArtifact]]: - if isinstance(self._input, tuple): + def input(self) -> ListArtifact: + if isinstance(self._input, ListArtifact): + return self._input + elif isinstance(self._input, tuple): if isinstance(self._input[0], TextArtifact): query_text = self._input[0] else: query_text = TextArtifact(J2().render_from_string(self._input[0], **self.full_context)) - return query_text, self._input[1] + return ListArtifact([query_text, *self._input[1]]) elif isinstance(self._input, Callable): return self._input(self) else: @@ -49,8 +52,11 @@ def input(self) -> tuple[TextArtifact, list[ImageArtifact]]: @input.setter def input( self, - value: tuple[TextArtifact, list[ImageArtifact]] - | Callable[[BaseTask], tuple[TextArtifact, list[ImageArtifact]]], + value: ( + tuple[str, list[ImageArtifact]] + | tuple[TextArtifact, list[ImageArtifact]] + | Callable[[BaseTask], ListArtifact] + ), ) -> None: self._input = value @@ -68,8 +74,13 @@ def image_query_engine(self, value: ImageQueryEngine) -> None: self._image_query_engine = value def run(self) -> TextArtifact: - query, image_artifacts = self.input + query = self.input.value[0] - response = self.image_query_engine.run(query.value, image_artifacts) + if all([isinstance(input, ImageArtifact) for input in self.input.value[1:]]): + image_artifacts = [input for input in self.input.value[1:] if isinstance(input, ImageArtifact)] + else: + raise ValueError("All inputs after the query must be ImageArtifacts.") + + self.output = self.image_query_engine.run(query.value, image_artifacts) - return response + return self.output diff --git a/griptape/tasks/inpainting_image_generation_task.py b/griptape/tasks/inpainting_image_generation_task.py index 374820d01..b0fae9118 100644 --- a/griptape/tasks/inpainting_image_generation_task.py +++ b/griptape/tasks/inpainting_image_generation_task.py @@ -5,7 +5,7 @@ from attrs import define, field from griptape.engines import InpaintingImageGenerationEngine -from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts import ImageArtifact, TextArtifact, ListArtifact from griptape.tasks import BaseImageGenerationTask, BaseTask from griptape.utils import J2 @@ -30,26 +30,29 @@ class InpaintingImageGenerationTask(BaseImageGenerationTask): default=None, kw_only=True, alias="image_generation_engine" ) _input: ( - tuple[str | TextArtifact, ImageArtifact, ImageArtifact] - | Callable[[BaseTask], tuple[TextArtifact, ImageArtifact, ImageArtifact]] + tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact ) = field(default=None) @property - def input(self) -> tuple[TextArtifact, ImageArtifact, ImageArtifact]: - if isinstance(self._input, tuple): + def input(self) -> ListArtifact: + if isinstance(self._input, ListArtifact): + return self._input + elif isinstance(self._input, tuple): if isinstance(self._input[0], TextArtifact): input_text = self._input[0] else: input_text = TextArtifact(J2().render_from_string(self._input[0], **self.full_context)) - return input_text, self._input[1], self._input[2] + return ListArtifact([input_text, self._input[1], self._input[2]]) elif isinstance(self._input, Callable): return self._input(self) else: raise ValueError("Input must be a tuple of (text, image, mask) or a callable that returns such a tuple.") @input.setter - def input(self, value: tuple[TextArtifact, ImageArtifact, ImageArtifact]) -> None: + def input( + self, value: tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] + ) -> None: self._input = value @property @@ -69,8 +72,14 @@ def image_generation_engine(self, value: InpaintingImageGenerationEngine) -> Non def run(self) -> ImageArtifact: prompt_artifact = self.input[0] + image_artifact = self.input[1] + if not isinstance(image_artifact, ImageArtifact): + raise ValueError("Image must be an ImageArtifact.") + mask_artifact = self.input[2] + if not isinstance(mask_artifact, ImageArtifact): + raise ValueError("Mask must be an ImageArtifact.") output_image_artifact = self.image_generation_engine.run( prompts=[prompt_artifact.to_text()], diff --git a/griptape/tasks/outpainting_image_generation_task.py b/griptape/tasks/outpainting_image_generation_task.py index a35671de3..61a7c1b8a 100644 --- a/griptape/tasks/outpainting_image_generation_task.py +++ b/griptape/tasks/outpainting_image_generation_task.py @@ -5,7 +5,7 @@ from attrs import define, field from griptape.engines import OutpaintingImageGenerationEngine -from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts import ImageArtifact, TextArtifact, ListArtifact from griptape.tasks import BaseImageGenerationTask, BaseTask from griptape.utils import J2 @@ -30,26 +30,29 @@ class OutpaintingImageGenerationTask(BaseImageGenerationTask): default=None, kw_only=True, alias="image_generation_engine" ) _input: ( - tuple[str | TextArtifact, ImageArtifact, ImageArtifact] - | Callable[[BaseTask], tuple[TextArtifact, ImageArtifact, ImageArtifact]] + tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact ) = field(default=None) @property - def input(self) -> tuple[TextArtifact, ImageArtifact, ImageArtifact]: - if isinstance(self._input, tuple): + def input(self) -> ListArtifact: + if isinstance(self._input, ListArtifact): + return self._input + elif isinstance(self._input, tuple): if isinstance(self._input[0], TextArtifact): input_text = self._input[0] else: input_text = TextArtifact(J2().render_from_string(self._input[0], **self.full_context)) - return input_text, self._input[1], self._input[2] + return ListArtifact([input_text, self._input[1], self._input[2]]) elif isinstance(self._input, Callable): return self._input(self) else: raise ValueError("Input must be a tuple of (text, image, mask) or a callable that returns such a tuple.") @input.setter - def input(self, value: tuple[TextArtifact, ImageArtifact, ImageArtifact]) -> None: + def input( + self, value: tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] + ) -> None: self._input = value @property @@ -70,8 +73,14 @@ def image_generation_engine(self, value: OutpaintingImageGenerationEngine) -> No def run(self) -> ImageArtifact: prompt_artifact = self.input[0] + image_artifact = self.input[1] + if not isinstance(image_artifact, ImageArtifact): + raise ValueError("Image must be an ImageArtifact.") + mask_artifact = self.input[2] + if not isinstance(mask_artifact, ImageArtifact): + raise ValueError("Mask must be an ImageArtifact.") output_image_artifact = self.image_generation_engine.run( prompts=[prompt_artifact.to_text()], diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index 0d1269840..6efba1e65 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -5,7 +5,7 @@ from attrs import define, field from griptape.engines import VariationImageGenerationEngine -from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts import ImageArtifact, TextArtifact, ListArtifact from griptape.tasks import BaseImageGenerationTask, BaseTask from griptape.utils import J2 @@ -29,26 +29,28 @@ class VariationImageGenerationTask(BaseImageGenerationTask): _image_generation_engine: VariationImageGenerationEngine = field( default=None, kw_only=True, alias="image_generation_engine" ) - _input: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], tuple[TextArtifact, ImageArtifact]] = field( + _input: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact = field( default=None ) @property - def input(self) -> tuple[TextArtifact, ImageArtifact]: - if isinstance(self._input, tuple): + def input(self) -> ListArtifact: + if isinstance(self._input, ListArtifact): + return self._input + elif isinstance(self._input, tuple): if isinstance(self._input[0], TextArtifact): input_text = self._input[0] else: input_text = TextArtifact(J2().render_from_string(self._input[0], **self.full_context)) - return input_text, self._input[1] + return ListArtifact([input_text, self._input[1]]) elif isinstance(self._input, Callable): return self._input(self) else: raise ValueError("Input must be a tuple of (text, image) or a callable that returns such a tuple.") @input.setter - def input(self, value: tuple[TextArtifact, ImageArtifact]) -> None: + def input(self, value: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact]) -> None: self._input = value @property @@ -68,7 +70,10 @@ def image_generation_engine(self, value: VariationImageGenerationEngine) -> None def run(self) -> ImageArtifact: prompt_artifact = self.input[0] + image_artifact = self.input[1] + if not isinstance(image_artifact, ImageArtifact): + raise ValueError("Image must be an ImageArtifact.") output_image_artifact = self.image_generation_engine.run( prompts=[prompt_artifact.to_text()], diff --git a/tests/unit/artifacts/test_list_artifact.py b/tests/unit/artifacts/test_list_artifact.py index cd0183703..044ca8ed5 100644 --- a/tests/unit/artifacts/test_list_artifact.py +++ b/tests/unit/artifacts/test_list_artifact.py @@ -24,7 +24,7 @@ def test___add__(self): def test_validate_value(self): with pytest.raises(ValueError): - ListArtifact([TextArtifact("foo"), BlobArtifact(b"bar")]) + ListArtifact([TextArtifact("foo"), BlobArtifact(b"bar")], validate_uniform_types=True) def test_child_type(self): assert ListArtifact([TextArtifact("foo")]).child_type == TextArtifact diff --git a/tests/unit/events/test_finish_structure_run_event.py b/tests/unit/events/test_finish_structure_run_event.py index 68ad1ea01..0e9e61f4f 100644 --- a/tests/unit/events/test_finish_structure_run_event.py +++ b/tests/unit/events/test_finish_structure_run_event.py @@ -1,5 +1,6 @@ import pytest -from griptape.artifacts.text_artifact import TextArtifact + +from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.events import FinishStructureRunEvent @@ -7,12 +8,19 @@ class TestFinishStructureRunEvent: @pytest.fixture def finish_structure_run_event(self): return FinishStructureRunEvent( - structure_id="fizz", output_task_input=TextArtifact("foo"), output_task_output=TextArtifact("bar") + structure_id="fizz", + output_task_input=ListArtifact( + [TextArtifact("foo"), ImageArtifact(b"", format="png", width=100, height=100)] + ), + output_task_output=TextArtifact("bar"), ) def test_to_dict(self, finish_structure_run_event): assert finish_structure_run_event.to_dict() is not None assert finish_structure_run_event.to_dict()["structure_id"] == "fizz" - assert finish_structure_run_event.to_dict()["output_task_input"]["value"] == "foo" + assert finish_structure_run_event.to_dict()["output_task_input"]["value"][0]["value"] == "foo" assert finish_structure_run_event.to_dict()["output_task_output"]["value"] == "bar" + + def test_from_dict(self, finish_structure_run_event): + assert FinishStructureRunEvent.from_dict(finish_structure_run_event.to_dict()) == finish_structure_run_event diff --git a/tests/unit/tasks/test_base_multi_text_input_task.py b/tests/unit/tasks/test_base_multi_text_input_task.py index 542162757..ad4776aee 100644 --- a/tests/unit/tasks/test_base_multi_text_input_task.py +++ b/tests/unit/tasks/test_base_multi_text_input_task.py @@ -1,7 +1,6 @@ from tests.mocks.mock_prompt_driver import MockPromptDriver from griptape.structures import Pipeline from griptape.artifacts import TextArtifact -from griptape.rules import Ruleset, Rule from tests.mocks.mock_multi_text_input_task import MockMultiTextInputTask diff --git a/tests/unit/tasks/test_image_query_task.py b/tests/unit/tasks/test_image_query_task.py index 1d195f835..dd4940213 100644 --- a/tests/unit/tasks/test_image_query_task.py +++ b/tests/unit/tasks/test_image_query_task.py @@ -1,14 +1,24 @@ -from griptape.engines import ImageQueryEngine +from unittest.mock import Mock import pytest -from griptape.tasks import BaseTask, ImageQueryTask -from griptape.artifacts import TextArtifact, ImageArtifact + +from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.artifacts.list_artifact import ListArtifact +from griptape.engines import ImageQueryEngine from griptape.structures import Agent +from griptape.tasks import BaseTask, ImageQueryTask from tests.mocks.mock_image_query_driver import MockImageQueryDriver from tests.mocks.mock_structure_config import MockStructureConfig class TestImageQueryTask: + @pytest.fixture + def image_query_engine(self) -> Mock: + mock = Mock() + mock.run.return_value = TextArtifact("image") + + return mock + @pytest.fixture def text_artifact(self): return TextArtifact(value="some text") @@ -20,24 +30,34 @@ def image_artifact(self): def test_text_inputs(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): task = ImageQueryTask((text_artifact.value, [image_artifact, image_artifact])) - assert task.input[0].value == text_artifact.value - assert task.input[1] == [image_artifact, image_artifact] + assert task.input.value[0].value == text_artifact.value + assert task.input.value[1] == image_artifact + assert task.input.value[2] == image_artifact def test_artifact_inputs(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): input_tuple = (text_artifact, [image_artifact, image_artifact]) task = ImageQueryTask(input_tuple) - assert task.input == input_tuple + assert task.input.value[0] == text_artifact + assert task.input.value[1] == image_artifact + assert task.input.value[2] == image_artifact def test_callable_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - input_tuple = (text_artifact, [image_artifact, image_artifact]) + input = [text_artifact, image_artifact, image_artifact] - def callable(task: BaseTask) -> tuple[TextArtifact, list[ImageArtifact]]: - return input_tuple + def callable(task: BaseTask) -> ListArtifact: + return ListArtifact(value=input) task = ImageQueryTask(callable) - assert task.input == input_tuple + assert task.input.value == input + + def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): + input = [text_artifact, image_artifact, image_artifact] + + task = ImageQueryTask(ListArtifact(value=input)) + + assert task.input.value == input def test_config_image_generation_engine(self, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) @@ -49,5 +69,15 @@ def test_config_image_generation_engine(self, text_artifact, image_artifact): def test_missing_image_generation_engine(self, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Image Query Engine"): task.image_query_engine + + def test_run(self, image_query_engine, text_artifact, image_artifact): + task = ImageQueryTask((text_artifact, [image_artifact, image_artifact]), image_query_engine=image_query_engine) + task.run() + + assert task.output.value == "image" + + def test_bad_run(self, image_query_engine, text_artifact, image_artifact): + with pytest.raises(ValueError, match="All inputs"): + ImageQueryTask(("foo", [image_artifact, text_artifact]), image_query_engine=image_query_engine).run() diff --git a/tests/unit/tasks/test_inpainting_image_generation_task.py b/tests/unit/tasks/test_inpainting_image_generation_task.py index 1fd84f2a6..9dc6aff54 100644 --- a/tests/unit/tasks/test_inpainting_image_generation_task.py +++ b/tests/unit/tasks/test_inpainting_image_generation_task.py @@ -1,5 +1,5 @@ +from griptape.artifacts.list_artifact import ListArtifact from griptape.engines import InpaintingImageGenerationEngine -from typing import Tuple from unittest.mock import Mock import pytest @@ -23,17 +23,30 @@ def test_artifact_inputs(self, text_artifact: TextArtifact, image_artifact: Imag input_tuple = (text_artifact, image_artifact, image_artifact) task = InpaintingImageGenerationTask(input_tuple, image_generation_engine=Mock()) - assert task.input == input_tuple + assert task.input.value == list(input_tuple) def test_callable_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - input_tuple = (text_artifact, image_artifact, image_artifact) + input = [text_artifact, image_artifact, image_artifact] - def callable(task: BaseTask) -> tuple[TextArtifact, ImageArtifact, ImageArtifact]: - return input_tuple + def callable(task: BaseTask) -> ListArtifact: + return ListArtifact(value=list(input)) task = InpaintingImageGenerationTask(callable, image_generation_engine=Mock()) - assert task.input == input_tuple + assert task.input.value == input + + def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): + input = [text_artifact, image_artifact] + task = InpaintingImageGenerationTask(ListArtifact(input), image_generation_engine=Mock()) + + assert task.input.value == input + + def test_bad_input(self, image_artifact): + with pytest.raises(ValueError): + InpaintingImageGenerationTask(("foo", "bar", image_artifact)).run() # pyright: ignore[reportArgumentType] + + with pytest.raises(ValueError): + InpaintingImageGenerationTask(("foo", image_artifact, "baz")).run() # pyright: ignore[reportArgumentType] def test_config_image_generation_engine(self, text_artifact, image_artifact): task = InpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) diff --git a/tests/unit/tasks/test_outpainting_image_generation_task.py b/tests/unit/tasks/test_outpainting_image_generation_task.py index f73467258..148ea133d 100644 --- a/tests/unit/tasks/test_outpainting_image_generation_task.py +++ b/tests/unit/tasks/test_outpainting_image_generation_task.py @@ -1,5 +1,5 @@ +from griptape.artifacts.list_artifact import ListArtifact from griptape.engines import OutpaintingImageGenerationEngine -from typing import Tuple from unittest.mock import Mock import pytest @@ -24,17 +24,30 @@ def test_artifact_inputs(self, text_artifact: TextArtifact, image_artifact: Imag input_tuple = (text_artifact, image_artifact, image_artifact) task = OutpaintingImageGenerationTask(input_tuple, image_generation_engine=Mock()) - assert task.input == input_tuple + assert task.input.value == list(input_tuple) def test_callable_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - input_tuple = (text_artifact, image_artifact, image_artifact) + input = [text_artifact, image_artifact, image_artifact] - def callable(task: BaseTask) -> tuple[TextArtifact, ImageArtifact, ImageArtifact]: - return input_tuple + def callable(task: BaseTask) -> ListArtifact: + return ListArtifact(input) task = OutpaintingImageGenerationTask(callable, image_generation_engine=Mock()) - assert task.input == input_tuple + assert task.input.value == input + + def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): + input = [text_artifact, image_artifact] + task = OutpaintingImageGenerationTask(ListArtifact(input), image_generation_engine=Mock()) + + assert task.input.value == input + + def test_bad_input(self, image_artifact): + with pytest.raises(ValueError): + OutpaintingImageGenerationTask(("foo", "bar", image_artifact)).run() # pyright: ignore[reportArgumentType] + + with pytest.raises(ValueError): + OutpaintingImageGenerationTask(("foo", image_artifact, "baz")).run() # pyright: ignore[reportArgumentType] def test_config_image_generation_engine(self, text_artifact, image_artifact): task = OutpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) diff --git a/tests/unit/tasks/test_variation_image_generation_task.py b/tests/unit/tasks/test_variation_image_generation_task.py index bef2107ac..6a9533da3 100644 --- a/tests/unit/tasks/test_variation_image_generation_task.py +++ b/tests/unit/tasks/test_variation_image_generation_task.py @@ -1,6 +1,6 @@ +from griptape.artifacts.list_artifact import ListArtifact from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from tests.mocks.mock_structure_config import MockStructureConfig -from typing import Tuple from unittest.mock import Mock import pytest @@ -23,17 +23,27 @@ def test_artifact_inputs(self, text_artifact: TextArtifact, image_artifact: Imag input_tuple = (text_artifact, image_artifact) task = VariationImageGenerationTask(input_tuple, image_generation_engine=Mock()) - assert task.input == input_tuple + assert task.input.value == list(input_tuple) def test_callable_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): - input_tuple = (text_artifact, image_artifact) + input = [text_artifact, image_artifact] - def callable(task: BaseTask) -> tuple[TextArtifact, ImageArtifact]: - return input_tuple + def callable(task: BaseTask) -> ListArtifact: + return ListArtifact(input) task = VariationImageGenerationTask(callable, image_generation_engine=Mock()) - assert task.input == input_tuple + assert task.input.value == input + + def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArtifact): + input = [text_artifact, image_artifact] + task = VariationImageGenerationTask(ListArtifact(input), image_generation_engine=Mock()) + + assert task.input.value == input + + def test_bad_input(self, image_artifact): + with pytest.raises(ValueError): + VariationImageGenerationTask(("foo", "bar")).run() # pyright: ignore[reportArgumentType] def test_config_image_generation_engine(self, text_artifact, image_artifact): task = VariationImageGenerationTask((text_artifact, image_artifact)) From 9f787338d0be8f2f8baba1e06af0257ba3baa458 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 31 May 2024 15:07:34 -0700 Subject: [PATCH 4/4] Fix type (#813) --- tests/mocks/mock_prompt_driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 20913a965..dc4cde69e 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -19,4 +19,4 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact: return TextArtifact(value=self.mock_output() if isinstance(self.mock_output, Callable) else self.mock_output) def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - yield TextArtifact(value=self.mock_output) + yield TextArtifact(value=self.mock_output() if isinstance(self.mock_output, Callable) else self.mock_output)