From 69bfa3cf630ef69f211ee8c9675eea3119c057c2 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 29 Apr 2024 09:19:15 -0700 Subject: [PATCH 1/9] Refactor how configs work --- docs/griptape-framework/misc/events.md | 2 +- docs/griptape-framework/structures/config.md | 2 +- griptape/config/__init__.py | 14 -- .../config/amazon_bedrock_structure_config.py | 80 ++++----- griptape/config/anthropic_structure_config.py | 55 ++---- griptape/config/base_structure_config.py | 21 ++- griptape/config/google_structure_config.py | 50 ++---- griptape/config/openai_structure_config.py | 56 ++---- griptape/config/structure_config.py | 60 +++---- .../config/structure_task_memory_config.py | 23 --- ...re_task_memory_extraction_engine_config.py | 18 -- ...ask_memory_extraction_engine_csv_config.py | 11 -- ...sk_memory_extraction_engine_json_config.py | 11 -- ...ructure_task_memory_query_engine_config.py | 22 --- ...cture_task_memory_summary_engine_config.py | 11 -- .../structure/summary_conversation_memory.py | 2 +- .../task/storage/text_artifact_storage.py | 8 +- griptape/structures/structure.py | 61 ++----- griptape/tasks/csv_extraction_task.py | 4 +- griptape/tasks/image_query_task.py | 4 +- .../tasks/inpainting_image_generation_task.py | 2 +- griptape/tasks/json_extraction_task.py | 4 +- .../outpainting_image_generation_task.py | 2 +- .../tasks/prompt_image_generation_task.py | 2 +- griptape/tasks/prompt_task.py | 2 +- griptape/tasks/text_query_task.py | 4 +- griptape/tasks/text_summary_task.py | 4 +- .../tasks/variation_image_generation_task.py | 2 +- griptape/utils/chat.py | 4 +- griptape/utils/prompt_stack.py | 2 +- griptape/utils/stream.py | 2 +- tests/mocks/mock_structure_config.py | 59 ++----- .../test_amazon_bedrock_structure_config.py | 134 ++++----------- .../config/test_anthropic_structure_config.py | 112 +++--------- .../config/test_google_structure_config.py | 107 +++--------- .../config/test_openai_structure_config.py | 161 +++++------------- tests/unit/config/test_structure_config.py | 90 ++-------- tests/unit/tasks/test_json_extraction_task.py | 6 +- 38 files changed, 305 insertions(+), 909 deletions(-) delete mode 100644 griptape/config/structure_task_memory_config.py delete mode 100644 griptape/config/structure_task_memory_extraction_engine_config.py delete mode 100644 griptape/config/structure_task_memory_extraction_engine_csv_config.py delete mode 100644 griptape/config/structure_task_memory_extraction_engine_json_config.py delete mode 100644 griptape/config/structure_task_memory_query_engine_config.py delete mode 100644 griptape/config/structure_task_memory_summary_engine_config.py diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index f45f77199..eda790c99 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -162,7 +162,7 @@ from griptape.tools import WebScraper pipeline = Pipeline() -pipeline.config.global_drivers.prompt_driver.stream = True +pipeline.config.prompt_driver.stream = True pipeline.add_tasks(ToolkitTask("Based on https://griptape.ai, tell me what griptape is.", tools=[WebScraper()])) for artifact in Stream(pipeline).run(): diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index c392200f9..b54c426d8 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -122,7 +122,7 @@ from griptape.config import AmazonBedrockStructureConfig from griptape.drivers import AmazonBedrockCohereEmbeddingDriver custom_config = AmazonBedrockStructureConfig() -custom_config.global_drivers.embedding_driver = AmazonBedrockCohereEmbeddingDriver() +custom_config.embedding_driver = AmazonBedrockCohereEmbeddingDriver() custom_config.merge_config( { "task_memory": { diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index 273a338f2..849556017 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -1,12 +1,5 @@ from .base_config import BaseConfig -from .structure_global_drivers_config import StructureGlobalDriversConfig -from .structure_task_memory_extraction_engine_csv_config import StructureTaskMemoryExtractionEngineCsvConfig -from .structure_task_memory_extraction_engine_json_config import StructureTaskMemoryExtractionEngineJsonConfig -from .structure_task_memory_extraction_engine_config import StructureTaskMemoryExtractionEngineConfig -from .structure_task_memory_query_engine_config import StructureTaskMemoryQueryEngineConfig -from .structure_task_memory_summary_engine_config import StructureTaskMemorySummaryEngineConfig -from .structure_task_memory_config import StructureTaskMemoryConfig from .base_structure_config import BaseStructureConfig from .structure_config import StructureConfig @@ -19,13 +12,6 @@ __all__ = [ "BaseConfig", "BaseStructureConfig", - "StructureTaskMemoryConfig", - "StructureGlobalDriversConfig", - "StructureTaskMemoryQueryEngineConfig", - "StructureTaskMemorySummaryEngineConfig", - "StructureTaskMemoryExtractionEngineConfig", - "StructureTaskMemoryExtractionEngineCsvConfig", - "StructureTaskMemoryExtractionEngineJsonConfig", "StructureConfig", "OpenAiStructureConfig", "AmazonBedrockStructureConfig", diff --git a/griptape/config/amazon_bedrock_structure_config.py b/griptape/config/amazon_bedrock_structure_config.py index 54b8d91c7..7caa55990 100644 --- a/griptape/config/amazon_bedrock_structure_config.py +++ b/griptape/config/amazon_bedrock_structure_config.py @@ -1,15 +1,6 @@ -from attrs import Factory, define, field +from attrs import define, field, Factory -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryExtractionEngineCsvConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, -) +from griptape.config import StructureConfig from griptape.drivers import ( AmazonBedrockImageGenerationDriver, AmazonBedrockImageQueryDriver, @@ -23,47 +14,44 @@ @define() -class AmazonBedrockStructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( +class AmazonBedrockStructureConfig(StructureConfig): + prompt_driver: AmazonBedrockPromptDriver = field( + default=Factory( + lambda: AmazonBedrockPromptDriver( + model="anthropic.claude-3-sonnet-20240229-v1:0", + stream=False, + prompt_model_driver=BedrockClaudePromptModelDriver(), + ) + ), + metadata={"serializable": True}, + ) + image_generation_driver: AmazonBedrockImageGenerationDriver = field( + default=Factory( + lambda: AmazonBedrockImageGenerationDriver( + model="amazon.titan-image-generator-v1", + image_generation_model_driver=BedrockTitanImageGenerationModelDriver(), + ) + ), + metadata={"serializable": True}, + ) + image_query_driver: AmazonBedrockImageQueryDriver = field( default=Factory( - lambda: StructureGlobalDriversConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", - stream=False, - prompt_model_driver=BedrockClaudePromptModelDriver(), - ), - image_generation_driver=AmazonBedrockImageGenerationDriver( - model="amazon.titan-image-generator-v1", - image_generation_model_driver=BedrockTitanImageGenerationModelDriver(), - ), - image_query_driver=AmazonBedrockImageQueryDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", - image_query_model_driver=BedrockClaudeImageQueryModelDriver(), - ), - embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1"), - vector_store_driver=LocalVectorStoreDriver( - embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") - ), + lambda: AmazonBedrockImageQueryDriver( + model="anthropic.claude-3-sonnet-20240229-v1:0", + image_query_model_driver=BedrockClaudeImageQueryModelDriver(), ) ), - kw_only=True, metadata={"serializable": True}, ) - task_memory: StructureTaskMemoryConfig = field( + embedding_driver: AmazonBedrockTitanEmbeddingDriver = field( + default=Factory(lambda: AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1")), + metadata={"serializable": True}, + ) + vector_store_driver: LocalVectorStoreDriver = field( default=Factory( - lambda self: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=self.global_drivers.prompt_driver, - vector_store_driver=self.global_drivers.vector_store_driver, - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), - json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - takes_self=True, + lambda: LocalVectorStoreDriver( + embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") + ) ), - kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/anthropic_structure_config.py b/griptape/config/anthropic_structure_config.py index 06978a5c2..6b147b7fd 100644 --- a/griptape/config/anthropic_structure_config.py +++ b/griptape/config/anthropic_structure_config.py @@ -1,54 +1,29 @@ -from attrs import Factory, define, field +from attrs import define, field, Factory -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryExtractionEngineCsvConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, -) +from griptape.config import StructureConfig from griptape.drivers import ( - LocalVectorStoreDriver, - AnthropicPromptDriver, AnthropicImageQueryDriver, + AnthropicPromptDriver, + LocalVectorStoreDriver, VoyageAiEmbeddingDriver, ) @define -class AnthropicStructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( +class AnthropicStructureConfig(StructureConfig): + prompt_driver: AnthropicPromptDriver = field( + default=Factory(lambda: AnthropicPromptDriver(model="claude-3-opus-20240229")), metadata={"serializable": True} + ) + embedding_driver: VoyageAiEmbeddingDriver = field( + default=Factory(lambda: VoyageAiEmbeddingDriver(model="voyage-large-2")), metadata={"serializable": True} + ) + vector_store_driver: LocalVectorStoreDriver = field( default=Factory( - lambda: StructureGlobalDriversConfig( - prompt_driver=AnthropicPromptDriver(model="claude-3-opus-20240229"), - embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2"), - vector_store_driver=LocalVectorStoreDriver( - embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2") - ), - image_query_driver=AnthropicImageQueryDriver(model="claude-3-opus-20240229"), - ) + lambda: LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")) ), - kw_only=True, metadata={"serializable": True}, ) - task_memory: StructureTaskMemoryConfig = field( - default=Factory( - lambda self: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=self.global_drivers.prompt_driver, - vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver), - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), - json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - takes_self=True, - ), - kw_only=True, + image_query_driver: AnthropicImageQueryDriver = field( + default=Factory(lambda: AnthropicImageQueryDriver(model="claude-3-opus-20240229")), metadata={"serializable": True}, ) diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index afa8e4012..4848eeda8 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -1,17 +1,32 @@ from __future__ import annotations from abc import ABC +from typing import Optional from attr import define, field -from griptape.config import BaseConfig, StructureGlobalDriversConfig, StructureTaskMemoryConfig +from griptape.config import BaseConfig +from griptape.drivers import ( + BaseConversationMemoryDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BaseImageQueryDriver, + BasePromptDriver, + BaseVectorStoreDriver, +) from griptape.utils import dict_merge @define class BaseStructureConfig(BaseConfig, ABC): - global_drivers: StructureGlobalDriversConfig = field(kw_only=True, metadata={"serializable": True}) - task_memory: StructureTaskMemoryConfig = field(kw_only=True, metadata={"serializable": True}) + prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) + image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) + image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) + embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True}) + vector_store_driver: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True}) + conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + default=None, kw_only=True, metadata={"serializable": True} + ) def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() diff --git a/griptape/config/google_structure_config.py b/griptape/config/google_structure_config.py index 9ba40622f..53f7bc35a 100644 --- a/griptape/config/google_structure_config.py +++ b/griptape/config/google_structure_config.py @@ -1,48 +1,20 @@ -from attrs import Factory, define, field +from attrs import define, field, Factory -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryExtractionEngineCsvConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, -) -from griptape.drivers import LocalVectorStoreDriver, GooglePromptDriver, GoogleEmbeddingDriver +from griptape.config import StructureConfig +from griptape.drivers import GoogleEmbeddingDriver, GooglePromptDriver, LocalVectorStoreDriver @define -class GoogleStructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( - default=Factory( - lambda: StructureGlobalDriversConfig( - prompt_driver=GooglePromptDriver(model="gemini-pro"), - embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001"), - vector_store_driver=LocalVectorStoreDriver( - embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001") - ), - ) - ), - kw_only=True, - metadata={"serializable": True}, +class GoogleStructureConfig(StructureConfig): + prompt_driver: GooglePromptDriver = field( + default=Factory(lambda: GooglePromptDriver(model="gemini-pro")), metadata={"serializable": True} + ) + embedding_driver: GoogleEmbeddingDriver = field( + default=Factory(lambda: GoogleEmbeddingDriver(model="models/embedding-001")), metadata={"serializable": True} ) - task_memory: StructureTaskMemoryConfig = field( + vector_store_driver: LocalVectorStoreDriver = field( default=Factory( - lambda self: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=self.global_drivers.prompt_driver, - vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver), - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), - json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - takes_self=True, + lambda: LocalVectorStoreDriver(embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001")) ), - kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_structure_config.py index 64c32ecec..0a06f6cf5 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_structure_config.py @@ -1,15 +1,6 @@ -from attrs import Factory, define, field +from attrs import define, Factory, field -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryExtractionEngineCsvConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, -) +from griptape.config import StructureConfig from griptape.drivers import ( LocalVectorStoreDriver, OpenAiChatPromptDriver, @@ -20,37 +11,24 @@ @define -class OpenAiStructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( - default=Factory( - lambda: StructureGlobalDriversConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), - image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"), - image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview"), - embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"), - vector_store_driver=LocalVectorStoreDriver( - embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small") - ), - ) - ), - kw_only=True, +class OpenAiStructureConfig(StructureConfig): + prompt_driver: OpenAiChatPromptDriver = field( + default=Factory(lambda: OpenAiChatPromptDriver(model="gpt-4o")), metadata={"serializable": True} + ) + image_generation_driver: OpenAiImageGenerationDriver = field( + default=Factory(lambda: OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512")), metadata={"serializable": True}, ) - task_memory: StructureTaskMemoryConfig = field( + image_query_driver: OpenAiVisionImageQueryDriver = field( + default=Factory(lambda: OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview")), + metadata={"serializable": True}, + ) + embedding_driver: OpenAiEmbeddingDriver = field( + default=Factory(lambda: OpenAiEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True} + ) + vector_store_driver: LocalVectorStoreDriver = field( default=Factory( - lambda self: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=self.global_drivers.prompt_driver, - vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver), - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), - json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - takes_self=True, + lambda: LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")) ), - kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/structure_config.py b/griptape/config/structure_config.py index 6381450f4..363bc6034 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/structure_config.py @@ -1,38 +1,40 @@ from attrs import Factory, define, field +from typing import Optional -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryExtractionEngineCsvConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, +from griptape.config import BaseStructureConfig + +from griptape.drivers import ( + BaseConversationMemoryDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BasePromptDriver, + BaseVectorStoreDriver, + DummyVectorStoreDriver, + DummyEmbeddingDriver, + DummyImageGenerationDriver, + DummyPromptDriver, + DummyImageQueryDriver, + BaseImageQueryDriver, ) -from griptape.drivers import LocalVectorStoreDriver @define class StructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( - default=Factory(lambda: StructureGlobalDriversConfig()), kw_only=True, metadata={"serializable": True} + prompt_driver: BasePromptDriver = field( + kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} + ) + image_generation_driver: BaseImageGenerationDriver = field( + kw_only=True, default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True} + ) + image_query_driver: BaseImageQueryDriver = field( + kw_only=True, default=Factory(lambda: DummyImageQueryDriver()), metadata={"serializable": True} + ) + embedding_driver: BaseEmbeddingDriver = field( + kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True} + ) + vector_store_driver: BaseVectorStoreDriver = field( + default=Factory(lambda: DummyVectorStoreDriver()), kw_only=True, metadata={"serializable": True} ) - task_memory: StructureTaskMemoryConfig = field( - default=Factory( - lambda self: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=self.global_drivers.prompt_driver, - vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver), - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), - json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - takes_self=True, - ), - kw_only=True, - metadata={"serializable": True}, + conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + default=None, kw_only=True, metadata={"serializable": True} ) diff --git a/griptape/config/structure_task_memory_config.py b/griptape/config/structure_task_memory_config.py deleted file mode 100644 index 3b8648dcf..000000000 --- a/griptape/config/structure_task_memory_config.py +++ /dev/null @@ -1,23 +0,0 @@ -from attrs import Factory, define, field - -from griptape.config import ( - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, -) -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemoryConfig(SerializableMixin): - query_engine: StructureTaskMemoryQueryEngineConfig = field( - kw_only=True, default=Factory(lambda: StructureTaskMemoryQueryEngineConfig()), metadata={"serializable": True} - ) - extraction_engine: StructureTaskMemoryExtractionEngineConfig = field( - kw_only=True, - default=Factory(lambda: StructureTaskMemoryExtractionEngineConfig()), - metadata={"serializable": True}, - ) - summary_engine: StructureTaskMemorySummaryEngineConfig = field( - kw_only=True, default=Factory(lambda: StructureTaskMemorySummaryEngineConfig()), metadata={"serializable": True} - ) diff --git a/griptape/config/structure_task_memory_extraction_engine_config.py b/griptape/config/structure_task_memory_extraction_engine_config.py deleted file mode 100644 index 8c2a58f02..000000000 --- a/griptape/config/structure_task_memory_extraction_engine_config.py +++ /dev/null @@ -1,18 +0,0 @@ -from attrs import Factory, define, field - -from griptape.config import StructureTaskMemoryExtractionEngineCsvConfig, StructureTaskMemoryExtractionEngineJsonConfig -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemoryExtractionEngineConfig(SerializableMixin): - csv: StructureTaskMemoryExtractionEngineCsvConfig = field( - kw_only=True, - default=Factory(lambda: StructureTaskMemoryExtractionEngineCsvConfig()), - metadata={"serializable": True}, - ) - json: StructureTaskMemoryExtractionEngineJsonConfig = field( - kw_only=True, - default=Factory(lambda: StructureTaskMemoryExtractionEngineJsonConfig()), - metadata={"serializable": True}, - ) diff --git a/griptape/config/structure_task_memory_extraction_engine_csv_config.py b/griptape/config/structure_task_memory_extraction_engine_csv_config.py deleted file mode 100644 index cce5f3e02..000000000 --- a/griptape/config/structure_task_memory_extraction_engine_csv_config.py +++ /dev/null @@ -1,11 +0,0 @@ -from attrs import define, field, Factory - -from griptape.drivers import BasePromptDriver, DummyPromptDriver -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemoryExtractionEngineCsvConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) diff --git a/griptape/config/structure_task_memory_extraction_engine_json_config.py b/griptape/config/structure_task_memory_extraction_engine_json_config.py deleted file mode 100644 index 04210b8c8..000000000 --- a/griptape/config/structure_task_memory_extraction_engine_json_config.py +++ /dev/null @@ -1,11 +0,0 @@ -from attrs import define, field, Factory - -from griptape.drivers import BasePromptDriver, DummyPromptDriver -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemoryExtractionEngineJsonConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) diff --git a/griptape/config/structure_task_memory_query_engine_config.py b/griptape/config/structure_task_memory_query_engine_config.py deleted file mode 100644 index 30d6bbcf9..000000000 --- a/griptape/config/structure_task_memory_query_engine_config.py +++ /dev/null @@ -1,22 +0,0 @@ -from attrs import Factory, define, field - -from griptape.drivers import ( - BasePromptDriver, - BaseVectorStoreDriver, - DummyVectorStoreDriver, - DummyEmbeddingDriver, - DummyPromptDriver, -) -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemoryQueryEngineConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) - vector_store_driver: BaseVectorStoreDriver = field( - kw_only=True, - default=Factory(lambda: DummyVectorStoreDriver(embedding_driver=DummyEmbeddingDriver())), - metadata={"serializable": True}, - ) diff --git a/griptape/config/structure_task_memory_summary_engine_config.py b/griptape/config/structure_task_memory_summary_engine_config.py deleted file mode 100644 index 100f9d8f1..000000000 --- a/griptape/config/structure_task_memory_summary_engine_config.py +++ /dev/null @@ -1,11 +0,0 @@ -from attrs import Factory, define, field - -from griptape.drivers import BasePromptDriver, DummyPromptDriver -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemorySummaryEngineConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 1f622b0d2..fa06a3c76 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -25,7 +25,7 @@ class SummaryConversationMemory(ConversationMemory): def prompt_driver(self) -> BasePromptDriver: if self._prompt_driver is None: if self.structure is not None: - self._prompt_driver = self.structure.config.global_drivers.prompt_driver + self._prompt_driver = self.structure.config.prompt_driver else: raise ValueError("Prompt Driver is not set.") return self._prompt_driver diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index e5cd73eb8..10912d6ae 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -11,9 +11,9 @@ @define class TextArtifactStorage(BaseArtifactStorage): query_engine: VectorQueryEngine = field(kw_only=True) - summary_engine: BaseSummaryEngine = field(kw_only=True) - csv_extraction_engine: CsvExtractionEngine = field(kw_only=True) - json_extraction_engine: JsonExtractionEngine = field(kw_only=True) + summary_engine: Optional[BaseSummaryEngine] = field(kw_only=True, default=None) + csv_extraction_engine: Optional[CsvExtractionEngine] = field(kw_only=True, default=None) + json_extraction_engine: Optional[JsonExtractionEngine] = field(kw_only=True, default=None) def can_store(self, artifact: BaseArtifact) -> bool: return isinstance(artifact, TextArtifact) @@ -28,6 +28,8 @@ def load_artifacts(self, namespace: str) -> ListArtifact: return self.query_engine.load_artifacts(namespace) def summarize(self, namespace: str) -> TextArtifact: + if self.summary_engine is None: + raise ValueError("Summary engine is not set.") return self.summary_engine.summarize_artifacts(self.load_artifacts(namespace)) def query(self, namespace: str, query: str, metadata: Any = None) -> TextArtifact: diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index c3912649d..feda77543 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -55,8 +55,7 @@ class Structure(ABC): event_listeners: list[EventListener] = field(factory=list, kw_only=True) conversation_memory: Optional[BaseConversationMemory] = field( default=Factory( - lambda self: ConversationMemory(driver=self.config.global_drivers.conversation_memory_driver), - takes_self=True, + lambda self: ConversationMemory(driver=self.config.conversation_memory_driver), takes_self=True ), kw_only=True, ) @@ -97,19 +96,17 @@ def __add__(self, other: BaseTask | list[BaseTask]) -> list[BaseTask]: @prompt_driver.validator # pyright: ignore def validate_prompt_driver(self, attribute, value): if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.global_drivers.prompt_driver` instead.") + deprecation_warn(f"`{attribute.name}` is deprecated, use `config.prompt_driver` instead.") @embedding_driver.validator # pyright: ignore def validate_embedding_driver(self, attribute, value): if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.global_drivers.embedding_driver` instead.") + deprecation_warn(f"`{attribute.name}` is deprecated, use `config.embedding_driver` instead.") @stream.validator # pyright: ignore def validate_stream(self, attribute, value): if value is not None: - deprecation_warn( - f"`{attribute.name}` is deprecated, use `config.global_drivers.prompt_driver.stream` instead." - ) + deprecation_warn(f"`{attribute.name}` is deprecated, use `config.prompt_driver.stream` instead.") @property def execution_args(self) -> tuple: @@ -161,15 +158,9 @@ def default_config(self) -> BaseStructureConfig: vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) - config.global_drivers.prompt_driver = prompt_driver - config.global_drivers.vector_store_driver = vector_store_driver - config.global_drivers.embedding_driver = embedding_driver - - config.task_memory.query_engine.prompt_driver = prompt_driver - config.task_memory.query_engine.vector_store_driver = vector_store_driver - config.task_memory.summary_engine.prompt_driver = prompt_driver - config.task_memory.extraction_engine.csv.prompt_driver = prompt_driver - config.task_memory.extraction_engine.json.prompt_driver = prompt_driver + config.prompt_driver = prompt_driver + config.vector_store_driver = vector_store_driver + config.embedding_driver = embedding_driver else: config = OpenAiStructureConfig() @@ -177,45 +168,15 @@ def default_config(self) -> BaseStructureConfig: @property def default_task_memory(self) -> TaskMemory: - global_drivers = self.config.global_drivers - task_memory = self.config.task_memory - return TaskMemory( artifact_storages={ TextArtifact: TextArtifactStorage( query_engine=VectorQueryEngine( - prompt_driver=( - global_drivers.prompt_driver - if isinstance(task_memory.query_engine.prompt_driver, DummyPromptDriver) - else task_memory.query_engine.prompt_driver - ), - vector_store_driver=( - global_drivers.vector_store_driver - if isinstance(task_memory.query_engine.prompt_driver, DummyVectorStoreDriver) - else task_memory.query_engine.vector_store_driver - ), - ), - summary_engine=PromptSummaryEngine( - prompt_driver=( - global_drivers.prompt_driver - if isinstance(task_memory.summary_engine.prompt_driver, DummyPromptDriver) - else task_memory.summary_engine.prompt_driver - ) - ), - csv_extraction_engine=CsvExtractionEngine( - prompt_driver=( - global_drivers.prompt_driver - if isinstance(task_memory.extraction_engine.csv.prompt_driver, DummyPromptDriver) - else task_memory.extraction_engine.csv.prompt_driver - ) - ), - json_extraction_engine=JsonExtractionEngine( - prompt_driver=( - global_drivers.prompt_driver - if isinstance(task_memory.extraction_engine.json.prompt_driver, DummyPromptDriver) - else task_memory.extraction_engine.json.prompt_driver - ) + prompt_driver=self.config.prompt_driver, vector_store_driver=self.config.vector_store_driver ), + summary_engine=PromptSummaryEngine(prompt_driver=self.config.prompt_driver), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=self.config.prompt_driver), + json_extraction_engine=JsonExtractionEngine(prompt_driver=self.config.prompt_driver), ), BlobArtifact: BlobArtifactStorage(), } diff --git a/griptape/tasks/csv_extraction_task.py b/griptape/tasks/csv_extraction_task.py index ae71ea5d4..2f5f3db56 100644 --- a/griptape/tasks/csv_extraction_task.py +++ b/griptape/tasks/csv_extraction_task.py @@ -12,9 +12,7 @@ class CsvExtractionTask(ExtractionTask): def extraction_engine(self) -> CsvExtractionEngine: if self._extraction_engine is None: if self.structure is not None: - self._extraction_engine = CsvExtractionEngine( - prompt_driver=self.structure.config.global_drivers.prompt_driver - ) + self._extraction_engine = CsvExtractionEngine(prompt_driver=self.structure.config.prompt_driver) else: raise ValueError("Extraction Engine is not set.") return self._extraction_engine diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py index 29077a055..94be4f483 100644 --- a/griptape/tasks/image_query_task.py +++ b/griptape/tasks/image_query_task.py @@ -58,9 +58,7 @@ def input( def image_query_engine(self) -> ImageQueryEngine: if self._image_query_engine is None: if self.structure is not None: - self._image_query_engine = ImageQueryEngine( - image_query_driver=self.structure.config.global_drivers.image_query_driver - ) + self._image_query_engine = ImageQueryEngine(image_query_driver=self.structure.config.image_query_driver) else: raise ValueError("Image Query Engine is not set.") return self._image_query_engine diff --git a/griptape/tasks/inpainting_image_generation_task.py b/griptape/tasks/inpainting_image_generation_task.py index 028ae336e..f3b2edb7a 100644 --- a/griptape/tasks/inpainting_image_generation_task.py +++ b/griptape/tasks/inpainting_image_generation_task.py @@ -57,7 +57,7 @@ def image_generation_engine(self) -> InpaintingImageGenerationEngine: if self._image_generation_engine is None: if self.structure is not None: self._image_generation_engine = InpaintingImageGenerationEngine( - image_generation_driver=self.structure.config.global_drivers.image_generation_driver + image_generation_driver=self.structure.config.image_generation_driver ) else: raise ValueError("Image Generation Engine is not set.") diff --git a/griptape/tasks/json_extraction_task.py b/griptape/tasks/json_extraction_task.py index a43b1e1e2..e1f082fd8 100644 --- a/griptape/tasks/json_extraction_task.py +++ b/griptape/tasks/json_extraction_task.py @@ -12,9 +12,7 @@ class JsonExtractionTask(ExtractionTask): def extraction_engine(self) -> JsonExtractionEngine: if self._extraction_engine is None: if self.structure is not None: - self._extraction_engine = JsonExtractionEngine( - prompt_driver=self.structure.config.global_drivers.prompt_driver - ) + self._extraction_engine = JsonExtractionEngine(prompt_driver=self.structure.config.prompt_driver) else: raise ValueError("Extraction Engine is not set.") return self._extraction_engine diff --git a/griptape/tasks/outpainting_image_generation_task.py b/griptape/tasks/outpainting_image_generation_task.py index 7b0ffd06b..fd2d335e5 100644 --- a/griptape/tasks/outpainting_image_generation_task.py +++ b/griptape/tasks/outpainting_image_generation_task.py @@ -57,7 +57,7 @@ def image_generation_engine(self) -> OutpaintingImageGenerationEngine: if self._image_generation_engine is None: if self.structure is not None: self._image_generation_engine = OutpaintingImageGenerationEngine( - image_generation_driver=self.structure.config.global_drivers.image_generation_driver + image_generation_driver=self.structure.config.image_generation_driver ) else: raise ValueError("Image Generation Engine is not set.") diff --git a/griptape/tasks/prompt_image_generation_task.py b/griptape/tasks/prompt_image_generation_task.py index 24971d1cb..93404ef84 100644 --- a/griptape/tasks/prompt_image_generation_task.py +++ b/griptape/tasks/prompt_image_generation_task.py @@ -50,7 +50,7 @@ def image_generation_engine(self) -> PromptImageGenerationEngine: if self._image_generation_engine is None: if self.structure is not None: self._image_generation_engine = PromptImageGenerationEngine( - image_generation_driver=self.structure.config.global_drivers.image_generation_driver + image_generation_driver=self.structure.config.image_generation_driver ) else: raise ValueError("Image Generation Engine is not set.") diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 525e21eef..16f7c6dac 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -42,7 +42,7 @@ def prompt_stack(self) -> PromptStack: def prompt_driver(self) -> BasePromptDriver: if self._prompt_driver is None: if self.structure is not None: - self._prompt_driver = self.structure.config.global_drivers.prompt_driver + self._prompt_driver = self.structure.config.prompt_driver else: raise ValueError("Prompt Driver is not set") return self._prompt_driver diff --git a/griptape/tasks/text_query_task.py b/griptape/tasks/text_query_task.py index 88fb94d10..5fee68103 100644 --- a/griptape/tasks/text_query_task.py +++ b/griptape/tasks/text_query_task.py @@ -18,8 +18,8 @@ def query_engine(self) -> BaseQueryEngine: if self._query_engine is None: if self.structure is not None: self._query_engine = VectorQueryEngine( - prompt_driver=self.structure.config.global_drivers.prompt_driver, - vector_store_driver=self.structure.config.global_drivers.vector_store_driver, + prompt_driver=self.structure.config.prompt_driver, + vector_store_driver=self.structure.config.vector_store_driver, ) else: raise ValueError("Query Engine is not set.") diff --git a/griptape/tasks/text_summary_task.py b/griptape/tasks/text_summary_task.py index 7cb6a8a0b..f10f851d0 100644 --- a/griptape/tasks/text_summary_task.py +++ b/griptape/tasks/text_summary_task.py @@ -17,9 +17,7 @@ class TextSummaryTask(BaseTextInputTask): def summary_engine(self) -> Optional[BaseSummaryEngine]: if self._summary_engine is None: if self.structure is not None: - self._summary_engine = PromptSummaryEngine( - prompt_driver=self.structure.config.global_drivers.prompt_driver - ) + self._summary_engine = PromptSummaryEngine(prompt_driver=self.structure.config.prompt_driver) else: raise ValueError("Summary Engine is not set.") return self._summary_engine diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index 650d395c3..1242bc59b 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -56,7 +56,7 @@ def image_generation_engine(self) -> VariationImageGenerationEngine: if self._image_generation_engine is None: if self.structure is not None: self._image_generation_engine = VariationImageGenerationEngine( - image_generation_driver=self.structure.config.global_drivers.image_generation_driver + image_generation_driver=self.structure.config.image_generation_driver ) else: raise ValueError("Image Generation Engine is not set.") diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 110a427da..549d93e53 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -21,7 +21,7 @@ class Chat: ) def default_output_fn(self, text: str) -> None: - if self.structure.config.global_drivers.prompt_driver.stream: + if self.structure.config.prompt_driver.stream: print(text, end="", flush=True) else: print(text) @@ -36,7 +36,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if self.structure.config.global_drivers.prompt_driver.stream: + if self.structure.config.prompt_driver.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/griptape/utils/prompt_stack.py b/griptape/utils/prompt_stack.py index f4dda1462..a5f336030 100644 --- a/griptape/utils/prompt_stack.py +++ b/griptape/utils/prompt_stack.py @@ -66,7 +66,7 @@ def add_conversation_memory(self, memory: BaseConversationMemory, index: Optiona if memory.autoprune and hasattr(memory, "structure"): should_prune = True - prompt_driver = memory.structure.config.global_drivers.prompt_driver + prompt_driver = memory.structure.config.prompt_driver temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index a0251fa57..8cb2c3a7c 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -33,7 +33,7 @@ class Stream: @structure.validator # pyright: ignore def validate_structure(self, _, structure: Structure): - if structure and not structure.config.global_drivers.prompt_driver.stream: + if structure and not structure.config.prompt_driver.stream: raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/mocks/mock_structure_config.py b/tests/mocks/mock_structure_config.py index 2ac901009..8309f541b 100644 --- a/tests/mocks/mock_structure_config.py +++ b/tests/mocks/mock_structure_config.py @@ -1,15 +1,5 @@ from attrs import define, field, Factory -from griptape.drivers import LocalVectorStoreDriver -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemorySummaryEngineConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryExtractionEngineCsvConfig, -) +from griptape.config import StructureConfig from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from tests.mocks.mock_image_query_driver import MockImageQueryDriver from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -17,41 +7,16 @@ @define -class MockStructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( - default=Factory( - lambda: StructureGlobalDriversConfig( - prompt_driver=MockPromptDriver(), - image_generation_driver=MockImageGenerationDriver(model="dall-e-2"), - image_query_driver=MockImageQueryDriver(model="gpt-4-vision-preview"), - embedding_driver=MockEmbeddingDriver(model="text-embedding-3-small"), - ) - ), - kw_only=True, - metadata={"serializable": True}, +class MockStructureConfig(StructureConfig): + prompt_driver: MockPromptDriver = field( + default=Factory(lambda: MockPromptDriver()), metadata={"serializable": True} ) - task_memory: StructureTaskMemoryConfig = field( - default=Factory( - lambda: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=MockPromptDriver(model="gpt-3.5-turbo"), - vector_store_driver=LocalVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(model="text-embedding-3-small") - ), - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig( - prompt_driver=MockPromptDriver(model="gpt-3.5-turbo") - ), - json=StructureTaskMemoryExtractionEngineJsonConfig( - prompt_driver=MockPromptDriver(model="gpt-3.5-turbo") - ), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig( - prompt_driver=MockPromptDriver(model="gpt-3.5-turbo") - ), - ) - ), - kw_only=True, - metadata={"serializable": True}, + image_generation_driver: MockImageGenerationDriver = field( + default=Factory(lambda: MockImageGenerationDriver(model="dall-e-2")), metadata={"serializable": True} + ) + image_query_driver: MockImageQueryDriver = field( + default=Factory(lambda: MockImageQueryDriver(model="gpt-4-vision-preview")), metadata={"serializable": True} + ) + embedding_driver: MockEmbeddingDriver = field( + default=Factory(lambda: MockEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True} ) diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index 05eda5644..d787897cf 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -14,115 +14,43 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { - "global_drivers": { - "conversation_memory_driver": None, + "conversation_memory_driver": None, + "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, + "image_generation_driver": { + "image_generation_model_driver": { + "cfg_scale": 7, + "outpainting_mode": "PRECISE", + "quality": "standard", + "type": "BedrockTitanImageGenerationModelDriver", + }, + "image_height": 512, + "image_width": 512, + "model": "amazon.titan-image-generator-v1", + "seed": None, + "type": "AmazonBedrockImageGenerationDriver", + }, + "image_query_driver": { + "type": "AmazonBedrockImageQueryDriver", + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "max_tokens": 256, + "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, + }, + "prompt_driver": { + "max_tokens": None, + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, + "stream": False, + "temperature": 0.1, + "type": "AmazonBedrockPromptDriver", + }, + "vector_store_driver": { "embedding_driver": { "model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver", }, - "image_generation_driver": { - "image_generation_model_driver": { - "cfg_scale": 7, - "outpainting_mode": "PRECISE", - "quality": "standard", - "type": "BedrockTitanImageGenerationModelDriver", - }, - "image_height": 512, - "image_width": 512, - "model": "amazon.titan-image-generator-v1", - "seed": None, - "type": "AmazonBedrockImageGenerationDriver", - }, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "image_query_driver": { - "type": "AmazonBedrockImageQueryDriver", - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "max_tokens": 256, - "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, - }, - "prompt_driver": { - "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, - "stream": False, - "temperature": 0.1, - "type": "AmazonBedrockPromptDriver", - }, - "type": "StructureGlobalDriversConfig", - "vector_store_driver": { - "embedding_driver": { - "model": "amazon.titan-embed-text-v1", - "type": "AmazonBedrockTitanEmbeddingDriver", - }, - "type": "LocalVectorStoreDriver", - }, + "type": "LocalVectorStoreDriver", }, "type": "AmazonBedrockStructureConfig", - "task_memory": { - "type": "StructureTaskMemoryConfig", - "query_engine": { - "type": "StructureTaskMemoryQueryEngineConfig", - "prompt_driver": { - "type": "AmazonBedrockPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, - "stream": False, - }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "AmazonBedrockTitanEmbeddingDriver", - "model": "amazon.titan-embed-text-v1", - }, - }, - }, - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "AmazonBedrockPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": { - "type": "BedrockClaudePromptModelDriver", - "top_k": 250, - "top_p": 0.999, - }, - "stream": False, - }, - }, - "json": { - "type": "StructureTaskMemoryExtractionEngineJsonConfig", - "prompt_driver": { - "type": "AmazonBedrockPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": { - "type": "BedrockClaudePromptModelDriver", - "top_k": 250, - "top_p": 0.999, - }, - "stream": False, - }, - }, - }, - "summary_engine": { - "type": "StructureTaskMemorySummaryEngineConfig", - "prompt_driver": { - "type": "AmazonBedrockPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, - "stream": False, - }, - }, - }, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index b637eb929..596365eb2 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -15,101 +15,35 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "AnthropicStructureConfig", - "global_drivers": { - "type": "StructureGlobalDriversConfig", - "prompt_driver": { - "type": "AnthropicPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "claude-3-opus-20240229", - "top_p": 0.999, - "top_k": 250, - }, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "image_query_driver": { - "type": "AnthropicImageQueryDriver", - "model": "claude-3-opus-20240229", - "max_tokens": 256, - }, + "prompt_driver": { + "type": "AnthropicPromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "model": "claude-3-opus-20240229", + "top_p": 0.999, + "top_k": 250, + }, + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": { + "type": "AnthropicImageQueryDriver", + "model": "claude-3-opus-20240229", + "max_tokens": 256, + }, + "embedding_driver": { + "type": "VoyageAiEmbeddingDriver", + "model": "voyage-large-2", + "input_type": "document", + }, + "vector_store_driver": { + "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "VoyageAiEmbeddingDriver", "model": "voyage-large-2", "input_type": "document", }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "VoyageAiEmbeddingDriver", - "model": "voyage-large-2", - "input_type": "document", - }, - }, - "conversation_memory_driver": None, - }, - "task_memory": { - "type": "StructureTaskMemoryConfig", - "query_engine": { - "type": "StructureTaskMemoryQueryEngineConfig", - "prompt_driver": { - "type": "AnthropicPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "claude-3-opus-20240229", - "top_p": 0.999, - "top_k": 250, - }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "VoyageAiEmbeddingDriver", - "model": "voyage-large-2", - "input_type": "document", - }, - }, - }, - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "AnthropicPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "claude-3-opus-20240229", - "top_p": 0.999, - "top_k": 250, - }, - }, - "json": { - "type": "StructureTaskMemoryExtractionEngineJsonConfig", - "prompt_driver": { - "type": "AnthropicPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "claude-3-opus-20240229", - "top_p": 0.999, - "top_k": 250, - }, - }, - }, - "summary_engine": { - "type": "StructureTaskMemorySummaryEngineConfig", - "prompt_driver": { - "type": "AnthropicPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "claude-3-opus-20240229", - "top_p": 0.999, - "top_k": 250, - }, - }, }, + "conversation_memory_driver": None, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_structure_config.py index 2a4d50641..dcca9e29d 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_structure_config.py @@ -14,100 +14,33 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "GoogleStructureConfig", - "global_drivers": { - "type": "StructureGlobalDriversConfig", - "prompt_driver": { - "type": "GooglePromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "gemini-pro", - "top_p": None, - "top_k": None, - }, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, + "prompt_driver": { + "type": "GooglePromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "model": "gemini-pro", + "top_p": None, + "top_k": None, + }, + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": {"type": "DummyImageQueryDriver"}, + "embedding_driver": { + "type": "GoogleEmbeddingDriver", + "model": "models/embedding-001", + "task_type": "retrieval_document", + "title": None, + }, + "vector_store_driver": { + "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "GoogleEmbeddingDriver", "model": "models/embedding-001", "task_type": "retrieval_document", "title": None, }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "GoogleEmbeddingDriver", - "model": "models/embedding-001", - "task_type": "retrieval_document", - "title": None, - }, - }, - "conversation_memory_driver": None, - }, - "task_memory": { - "type": "StructureTaskMemoryConfig", - "query_engine": { - "type": "StructureTaskMemoryQueryEngineConfig", - "prompt_driver": { - "type": "GooglePromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "gemini-pro", - "top_p": None, - "top_k": None, - }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "GoogleEmbeddingDriver", - "model": "models/embedding-001", - "task_type": "retrieval_document", - "title": None, - }, - }, - }, - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "GooglePromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "gemini-pro", - "top_p": None, - "top_k": None, - }, - }, - "json": { - "type": "StructureTaskMemoryExtractionEngineJsonConfig", - "prompt_driver": { - "type": "GooglePromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "gemini-pro", - "top_p": None, - "top_k": None, - }, - }, - }, - "summary_engine": { - "type": "StructureTaskMemorySummaryEngineConfig", - "prompt_driver": { - "type": "GooglePromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "gemini-pro", - "top_p": None, - "top_k": None, - }, - }, }, + "conversation_memory_driver": None, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index 4a0ab7369..a2df52216 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -14,132 +14,53 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "OpenAiStructureConfig", - "global_drivers": { - "type": "StructureGlobalDriversConfig", - "prompt_driver": { - "type": "OpenAiChatPromptDriver", - "base_url": None, - "model": "gpt-4o", - "organization": None, - "response_format": None, - "seed": None, - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "user": "", - }, - "conversation_memory_driver": None, + "prompt_driver": { + "type": "OpenAiChatPromptDriver", + "base_url": None, + "model": "gpt-4", + "organization": None, + "response_format": None, + "seed": None, + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "user": "", + }, + "conversation_memory_driver": None, + "embedding_driver": { + "base_url": None, + "model": "text-embedding-3-small", + "organization": None, + "type": "OpenAiEmbeddingDriver", + }, + "image_generation_driver": { + "api_version": None, + "base_url": None, + "image_size": "512x512", + "model": "dall-e-2", + "organization": None, + "quality": "standard", + "response_format": "b64_json", + "style": None, + "type": "OpenAiImageGenerationDriver", + }, + "image_query_driver": { + "api_version": None, + "base_url": None, + "image_quality": "auto", + "max_tokens": 256, + "model": "gpt-4-vision-preview", + "organization": None, + "type": "OpenAiVisionImageQueryDriver", + }, + "vector_store_driver": { "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", "organization": None, "type": "OpenAiEmbeddingDriver", }, - "image_generation_driver": { - "api_version": None, - "base_url": None, - "image_size": "512x512", - "model": "dall-e-2", - "organization": None, - "quality": "standard", - "response_format": "b64_json", - "style": None, - "type": "OpenAiImageGenerationDriver", - }, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "image_query_driver": { - "api_version": None, - "base_url": None, - "image_quality": "auto", - "max_tokens": 256, - "model": "gpt-4-vision-preview", - "organization": None, - "type": "OpenAiVisionImageQueryDriver", - }, - "vector_store_driver": { - "embedding_driver": { - "base_url": None, - "model": "text-embedding-3-small", - "organization": None, - "type": "OpenAiEmbeddingDriver", - }, - "type": "LocalVectorStoreDriver", - }, - }, - "task_memory": { - "type": "StructureTaskMemoryConfig", - "query_engine": { - "type": "StructureTaskMemoryQueryEngineConfig", - "prompt_driver": { - "base_url": None, - "type": "OpenAiChatPromptDriver", - "model": "gpt-4o", - "organization": None, - "response_format": None, - "seed": None, - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "user": "", - }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "OpenAiEmbeddingDriver", - "base_url": None, - "organization": None, - "model": "text-embedding-3-small", - }, - }, - }, - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "OpenAiChatPromptDriver", - "base_url": None, - "model": "gpt-4o", - "organization": None, - "response_format": None, - "seed": None, - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "user": "", - }, - }, - "json": { - "type": "StructureTaskMemoryExtractionEngineJsonConfig", - "prompt_driver": { - "type": "OpenAiChatPromptDriver", - "base_url": None, - "model": "gpt-4o", - "organization": None, - "response_format": None, - "seed": None, - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "user": "", - }, - }, - }, - "summary_engine": { - "type": "StructureTaskMemorySummaryEngineConfig", - "prompt_driver": { - "type": "OpenAiChatPromptDriver", - "base_url": None, - "model": "gpt-4o", - "organization": None, - "response_format": None, - "seed": None, - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "user": "", - }, - }, + "type": "LocalVectorStoreDriver", }, } diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index 16f189b02..5cc3e2561 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -10,64 +10,14 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "StructureConfig", - "global_drivers": { - "type": "StructureGlobalDriversConfig", - "prompt_driver": {"type": "DummyPromptDriver", "temperature": 0.1, "max_tokens": None, "stream": False}, - "conversation_memory_driver": None, + "prompt_driver": {"type": "DummyPromptDriver", "temperature": 0.1, "max_tokens": None, "stream": False}, + "conversation_memory_driver": None, + "embedding_driver": {"type": "DummyEmbeddingDriver"}, + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": {"type": "DummyImageQueryDriver"}, + "vector_store_driver": { "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "vector_store_driver": { - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "type": "DummyVectorStoreDriver", - }, - }, - "task_memory": { - "type": "StructureTaskMemoryConfig", - "query_engine": { - "type": "StructureTaskMemoryQueryEngineConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "stream": False, - "temperature": 0.1, - "max_tokens": None, - }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - }, - }, - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - }, - "json": { - "type": "StructureTaskMemoryExtractionEngineJsonConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - }, - }, - "summary_engine": { - "type": "StructureTaskMemorySummaryEngineConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - }, + "type": "DummyVectorStoreDriver", }, } @@ -79,19 +29,11 @@ def test_unchanged_merge_config(self, config): config.merge_config( { "type": "StructureConfig", - "task_memory": { - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - }, - } + "prompt_driver": { + "type": "DummyPromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, }, } ).to_dict() @@ -100,12 +42,12 @@ def test_unchanged_merge_config(self, config): def test_changed_merge_config(self, config): config = config.merge_config( - {"task_memory": {"extraction_engine": {"csv": {"prompt_driver": {"stream": True}}}}} + {"prompt_driver": {"type": "DummyPromptDriver", "temperature": 0.1, "max_tokens": None, "stream": False}} ) - assert config.task_memory.extraction_engine.csv.prompt_driver.stream is True + assert config.prompt_driver.temperature == 0.1 def test_dot_update(self, config): - config.task_memory.extraction_engine.csv.prompt_driver.stream = True + config.prompt_driver.max_tokens = 10 - assert config.task_memory.extraction_engine.csv.prompt_driver.stream is True + assert config.prompt_driver.max_tokens == 10 diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index 1ffa24ecd..0366652b0 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -14,10 +14,8 @@ def task(self): def test_run(self, task): mock_config = MockStructureConfig() - assert isinstance(mock_config.global_drivers.prompt_driver, MockPromptDriver) - mock_config.global_drivers.prompt_driver.mock_output = ( - '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' - ) + assert isinstance(mock_config.prompt_driver, MockPromptDriver) + mock_config.prompt_driver.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' agent = Agent(config=mock_config) agent.add_task(task) From 9a10618ce6ece7321aa8ca40369ff1a4b5341459 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 30 Apr 2024 16:19:22 -0700 Subject: [PATCH 2/9] Update docs --- docs/examples/multiple-agent-shared-memory.md | 10 +- .../drivers/embedding-drivers.md | 6 +- .../drivers/prompt-drivers.md | 188 ++++++++---------- docs/griptape-framework/structures/config.md | 49 +---- .../official-tools/rest-api-client.md | 12 +- 5 files changed, 99 insertions(+), 166 deletions(-) diff --git a/docs/examples/multiple-agent-shared-memory.md b/docs/examples/multiple-agent-shared-memory.md index cf3966c97..ac69abcd4 100644 --- a/docs/examples/multiple-agent-shared-memory.md +++ b/docs/examples/multiple-agent-shared-memory.md @@ -15,7 +15,7 @@ from griptape.engines import VectorQueryEngine, PromptSummaryEngine, CsvExtracti from griptape.memory import TaskMemory from griptape.artifacts import TextArtifact from griptape.memory.task.storage import TextArtifactStorage -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig AZURE_OPENAI_ENDPOINT_1 = os.environ["AZURE_OPENAI_ENDPOINT_1"] @@ -59,11 +59,9 @@ loader = Agent( WebScraper() ], config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=azure_prompt_driver, - vector_store_driver=mongo_driver, - embedding_driver=azure_embedding_driver - ) + prompt_driver=azure_prompt_driver, + vector_store_driver=mongo_driver, + embedding_driver=azure_embedding_driver ), ) asker = Agent( diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 332e9674a..2b8824668 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -182,10 +182,8 @@ from griptape.config import ( agent = Agent( tools=[WebScraper(), TaskMemoryClient(off_prompt=False)], config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), - embedding_driver=VoyageAiEmbeddingDriver(), - ) + prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), + embedding_driver=VoyageAiEmbeddingDriver(), ), ) diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index db6bf5d96..848d60115 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -8,13 +8,11 @@ You can instantiate drivers and pass them to structures: from griptape.structures import Agent from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4", temperature=0.3), - ) + prompt_driver=OpenAiChatPromptDriver(model="gpt-4", temperature=0.3), ), input_template="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", rules=[ @@ -70,18 +68,16 @@ import os from griptape.structures import Agent from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=OpenAiChatPromptDriver( - api_key=os.environ["OPENAI_API_KEY"], - temperature=0.1, - model="gpt-3.5-turbo", - response_format="json_object", - seed=42, - ) + prompt_driver=OpenAiChatPromptDriver( + api_key=os.environ["OPENAI_API_KEY"], + temperature=0.1, + model="gpt-3.5-turbo", + response_format="json_object", + seed=42, ) ), input_template="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}", @@ -107,17 +103,15 @@ import os from griptape.structures import Agent from griptape.rules import Rule from griptape.drivers import AzureOpenAiChatPromptDriver -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=AzureOpenAiChatPromptDriver( - api_key=os.environ["AZURE_OPENAI_API_KEY_1"], - model="gpt-3.5-turbo-16k", - azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_ID"], - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], - ) + prompt_driver=AzureOpenAiChatPromptDriver( + api_key=os.environ["AZURE_OPENAI_API_KEY_1"], + model="gpt-3.5-turbo-16k", + azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_ID"], + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], ) ), rules=[ @@ -139,18 +133,16 @@ The [AzureOpenAiCompletionPromptDriver](../../reference/griptape/drivers/prompt/ import os from griptape.structures import Agent from griptape.drivers import AzureOpenAiCompletionPromptDriver -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=AzureOpenAiCompletionPromptDriver( - api_key=os.environ["AZURE_OPENAI_API_KEY_1"], - model="text-davinci-003", - azure_deployment=os.environ["AZURE_OPENAI_DAVINCI_DEPLOYMENT_ID"], - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], - temperature=1 - ) + prompt_driver=AzureOpenAiCompletionPromptDriver( + api_key=os.environ["AZURE_OPENAI_API_KEY_1"], + model="text-davinci-003", + azure_deployment=os.environ["AZURE_OPENAI_DAVINCI_DEPLOYMENT_ID"], + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], + temperature=1 ) ) ) @@ -176,15 +168,13 @@ The [CoherePromptDriver](../../reference/griptape/drivers/prompt/cohere_prompt_d import os from griptape.structures import Agent from griptape.drivers import CoherePromptDriver -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=CoherePromptDriver( - model="command", - api_key=os.environ['COHERE_API_KEY'], - ) + prompt_driver=CoherePromptDriver( + model="command", + api_key=os.environ['COHERE_API_KEY'], ) ) ) @@ -203,15 +193,13 @@ The [AnthropicPromptDriver](../../reference/griptape/drivers/prompt/anthropic_pr import os from griptape.structures import Agent from griptape.drivers import AnthropicPromptDriver -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=AnthropicPromptDriver( - model="claude-3-opus-20240229", - api_key=os.environ['ANTHROPIC_API_KEY'], - ) + prompt_driver=AnthropicPromptDriver( + model="claude-3-opus-20240229", + api_key=os.environ['ANTHROPIC_API_KEY'], ) ) ) @@ -230,15 +218,13 @@ The [GooglePromptDriver](../../reference/griptape/drivers/prompt/google_prompt_d import os from griptape.structures import Agent from griptape.drivers import GooglePromptDriver -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=GooglePromptDriver( - model="gemini-pro", - api_key=os.environ['GOOGLE_API_KEY'], - ) + prompt_driver=GooglePromptDriver( + model="gemini-pro", + api_key=os.environ['GOOGLE_API_KEY'], ) ) ) @@ -264,7 +250,7 @@ from griptape.structures import Agent from griptape.drivers import HuggingFaceHubPromptDriver from griptape.rules import Rule, Ruleset from griptape.utils import PromptStack -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str: @@ -284,12 +270,10 @@ def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str: agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=HuggingFaceHubPromptDriver( - model="tiiuae/falcon-7b-instruct", - api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], - prompt_stack_to_string=prompt_stack_to_string_converter, - ) + prompt_driver=HuggingFaceHubPromptDriver( + model="tiiuae/falcon-7b-instruct", + api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], + prompt_stack_to_string=prompt_stack_to_string_converter, ) ), rulesets=[ @@ -317,17 +301,15 @@ The [HuggingFaceHubPromptDriver](#hugging-face-hub) also supports [Text Generati import os from griptape.structures import Agent from griptape.drivers import HuggingFaceHubPromptDriver -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=HuggingFaceHubPromptDriver( - model="http://127.0.0.1:8080", - api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], - ), - ) + prompt_driver=HuggingFaceHubPromptDriver( + model="http://127.0.0.1:8080", + api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], + ), ), ) @@ -353,7 +335,7 @@ from griptape.structures import Agent from griptape.drivers import HuggingFaceHubPromptDriver from griptape.rules import Rule, Ruleset from griptape.utils import PromptStack -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig # Override the default Prompt Stack to string converter @@ -375,12 +357,10 @@ def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str: agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=HuggingFaceHubPromptDriver( - model="tiiuae/falcon-7b-instruct", - api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], - prompt_stack_to_string=prompt_stack_to_string_converter, - ), + prompt_driver=HuggingFaceHubPromptDriver( + model="tiiuae/falcon-7b-instruct", + api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], + prompt_stack_to_string=prompt_stack_to_string_converter, ) ), rulesets=[ @@ -423,16 +403,14 @@ from griptape.drivers import ( SageMakerLlamaPromptModelDriver, ) from griptape.rules import Rule -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=AmazonSageMakerPromptDriver( - model=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"], - prompt_model_driver=SageMakerLlamaPromptModelDriver(), - temperature=0.75, - ), + prompt_driver=AmazonSageMakerPromptDriver( + model=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"], + prompt_model_driver=SageMakerLlamaPromptModelDriver(), + temperature=0.75, ) ), rules=[ @@ -455,15 +433,13 @@ from griptape.drivers import ( AmazonSageMakerPromptDriver, SageMakerFalconPromptModelDriver, ) -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=AmazonSageMakerPromptDriver( - model=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"], - prompt_model_driver=SageMakerFalconPromptModelDriver(), - ), + prompt_driver=AmazonSageMakerPromptDriver( + model=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"], + prompt_model_driver=SageMakerFalconPromptModelDriver(), ) ) ) @@ -486,16 +462,14 @@ To use this model with Amazon Bedrock, use the [BedrockTitanPromptModelDriver](. ```python from griptape.structures import Agent from griptape.drivers import AmazonBedrockPromptDriver, BedrockTitanPromptModelDriver -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="amazon.titan-text-express-v1", - prompt_model_driver=BedrockTitanPromptModelDriver( - top_p=1, - ) + prompt_driver=AmazonBedrockPromptDriver( + model="amazon.titan-text-express-v1", + prompt_model_driver=BedrockTitanPromptModelDriver( + top_p=1, ) ) ) @@ -515,17 +489,15 @@ To use this model with Amazon Bedrock, use the [BedrockClaudePromptModelDriver]( from griptape.structures import Agent from griptape.drivers import AmazonBedrockPromptDriver, BedrockClaudePromptModelDriver from griptape.rules import Rule -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", - prompt_model_driver=BedrockClaudePromptModelDriver( - top_p=1, - ), - ), + prompt_driver=AmazonBedrockPromptDriver( + model="anthropic.claude-3-sonnet-20240229-v1:0", + prompt_model_driver=BedrockClaudePromptModelDriver( + top_p=1, + ) ) ), rules=[ @@ -554,15 +526,13 @@ To use this model with Amazon Bedrock, use the [BedrockLlamaPromptModelDriver](. ```python from griptape.structures import Agent from griptape.drivers import AmazonBedrockPromptDriver, BedrockLlamaPromptModelDriver -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="meta.llama2-13b-chat-v1", - prompt_model_driver=BedrockLlamaPromptModelDriver(), - ), + prompt_driver=AmazonBedrockPromptDriver( + model="meta.llama2-13b-chat-v1", + prompt_model_driver=BedrockLlamaPromptModelDriver(), ) ) ) @@ -578,16 +548,14 @@ To use this model with Amazon Bedrock, use the [BedrockJurassicPromptModelDriver ```python from griptape.structures import Agent from griptape.drivers import AmazonBedrockPromptDriver, BedrockJurassicPromptModelDriver -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="ai21.j2-ultra-v1", - prompt_model_driver=BedrockJurassicPromptModelDriver(top_p=0.95), - temperature=0.7, - ) + prompt_driver=AmazonBedrockPromptDriver( + model="ai21.j2-ultra-v1", + prompt_model_driver=BedrockJurassicPromptModelDriver(top_p=0.95), + temperature=0.7, ) ) ) diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index b54c426d8..817dfb7fe 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -74,44 +74,19 @@ This approach ensures that you are informed through clear error messages if you ```python import os from griptape.structures import Agent -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig from griptape.drivers import AnthropicPromptDriver agent = Agent( config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=AnthropicPromptDriver( - model="claude-3-sonnet-20240229", - api_key=os.environ["ANTHROPIC_API_KEY"], - ) + prompt_driver=AnthropicPromptDriver( + model="claude-3-sonnet-20240229", + api_key=os.environ["ANTHROPIC_API_KEY"], ) ), ) ``` -### Task Memory - -Griptape allows for detailed control over [Task Memory](./task-memory.md) settings, permitting overrides on a per Engine basis, beyond the global Drivers configuration. - -```python -from griptape.structures import Agent -from griptape.config import StructureConfig, StructureTaskMemoryConfig, StructureTaskMemoryQueryEngineConfig -from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver - - -agent = Agent( - config=StructureConfig( - task_memory=StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - vector_store_driver=LocalVectorStoreDriver( - embedding_driver=OpenAiEmbeddingDriver(), - ) - ) - ) - ) -) -``` - ### Loading/Saving Configs Configuration classes in Griptape offer utility methods for loading, saving, and merging configurations, streamlining the management of complex setups. @@ -125,16 +100,12 @@ custom_config = AmazonBedrockStructureConfig() custom_config.embedding_driver = AmazonBedrockCohereEmbeddingDriver() custom_config.merge_config( { - "task_memory": { - "summary_engine": { - "prompt_driver": { - "model": "amazon.titan-text-express-v1", - "prompt_model_driver": { - "type": "BedrockTitanPromptModelDriver", - }, - } - } - } + "embedding_driver": { + "base_url": None, + "model": "text-embedding-3-small", + "organization": None, + "type": "OpenAiEmbeddingDriver", + }, } ) serialized_config = custom_config.to_json() diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index b253d0bce..7c10d615a 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -14,7 +14,7 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import RestApiClient -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig posts_client = RestApiClient( base_url="https://jsonplaceholder.typicode.com", @@ -119,12 +119,10 @@ posts_client = RestApiClient( pipeline = Pipeline( conversation_memory=ConversationMemory(), config = StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=OpenAiChatPromptDriver( - model="gpt-4", - temperature=0.1 - ), - ) + prompt_driver=OpenAiChatPromptDriver( + model="gpt-4", + temperature=0.1 + ), ), ) From 39660420b8f9883f5ad092ccd58f7a0873493777 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 30 Apr 2024 17:18:42 -0700 Subject: [PATCH 3/9] Clean up imports --- griptape/structures/structure.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index feda77543..23ce6584c 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -11,14 +11,7 @@ from griptape.artifacts import BlobArtifact, TextArtifact from griptape.config import BaseStructureConfig, OpenAiStructureConfig, StructureConfig -from griptape.drivers import ( - BaseEmbeddingDriver, - BasePromptDriver, - DummyPromptDriver, - DummyVectorStoreDriver, - OpenAiEmbeddingDriver, - OpenAiChatPromptDriver, -) +from griptape.drivers import BaseEmbeddingDriver, BasePromptDriver, OpenAiEmbeddingDriver, OpenAiChatPromptDriver from griptape.drivers.vector.local_vector_store_driver import LocalVectorStoreDriver from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine, VectorQueryEngine from griptape.events import BaseEvent, EventListener From 04c35cf9e551779fd8fd51309b6443d80db02a83 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 16 May 2024 06:50:21 -1000 Subject: [PATCH 4/9] Update gpt-4 to gpt-4o in some places --- docs/griptape-framework/data/chunkers.md | 2 +- docs/griptape-framework/drivers/embedding-drivers.md | 2 +- docs/griptape-framework/drivers/prompt-drivers.md | 2 +- docs/griptape-framework/misc/tokenizers.md | 2 +- docs/griptape-tools/official-tools/rest-api-client.md | 2 +- griptape/structures/structure.py | 2 +- tests/utils/structure_tester.py | 5 ++++- 7 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/griptape-framework/data/chunkers.md b/docs/griptape-framework/data/chunkers.md index 0df73f965..b67951b2a 100644 --- a/docs/griptape-framework/data/chunkers.md +++ b/docs/griptape-framework/data/chunkers.md @@ -15,7 +15,7 @@ from griptape.chunkers import TextChunker from griptape.tokenizers import OpenAiTokenizer TextChunker( # set an optional custom tokenizer - tokenizer=OpenAiTokenizer(model="gpt-4"), + tokenizer=OpenAiTokenizer(model="gpt-4o"), # optionally modify default number of tokens max_tokens=100 ).chunk("long text") diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 2b8824668..6fe2af60d 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -182,7 +182,7 @@ from griptape.config import ( agent = Agent( tools=[WebScraper(), TaskMemoryClient(off_prompt=False)], config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), embedding_driver=VoyageAiEmbeddingDriver(), ), ) diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 848d60115..ae798b22e 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -12,7 +12,7 @@ from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4", temperature=0.3), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), ), input_template="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", rules=[ diff --git a/docs/griptape-framework/misc/tokenizers.md b/docs/griptape-framework/misc/tokenizers.md index b523d04e4..aaf488187 100644 --- a/docs/griptape-framework/misc/tokenizers.md +++ b/docs/griptape-framework/misc/tokenizers.md @@ -13,7 +13,7 @@ Tokenizers are a low level abstraction that you will rarely interact with direct from griptape.tokenizers import OpenAiTokenizer -tokenizer = OpenAiTokenizer(model="gpt-4") +tokenizer = OpenAiTokenizer(model="gpt-4o") print(tokenizer.count_tokens("Hello world!")) print(tokenizer.count_input_tokens_left("Hello world!")) diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index 7c10d615a..a889f0e81 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -120,7 +120,7 @@ pipeline = Pipeline( conversation_memory=ConversationMemory(), config = StructureConfig( prompt_driver=OpenAiChatPromptDriver( - model="gpt-4", + model="gpt-4o", temperature=0.1 ), ), diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 23ce6584c..8b71dd905 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -137,7 +137,7 @@ def default_config(self) -> BaseStructureConfig: config = StructureConfig() if self.prompt_driver is None: - prompt_driver = OpenAiChatPromptDriver(model="gpt-4") + prompt_driver = OpenAiChatPromptDriver(model="gpt-4o") else: prompt_driver = self.prompt_driver diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index e9972660f..abd8f0e0a 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -62,6 +62,9 @@ class TesterPromptDriverOption: "OPENAI_CHAT_4": TesterPromptDriverOption( prompt_driver=OpenAiChatPromptDriver(model="gpt-4", api_key=os.environ["OPENAI_API_KEY"]), enabled=True ), + "OPENAI_CHAT_4o": TesterPromptDriverOption( + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", api_key=os.environ["OPENAI_API_KEY"]), enabled=True + ), "OPENAI_CHAT_4_1106_PREVIEW": TesterPromptDriverOption( prompt_driver=OpenAiChatPromptDriver(model="gpt-4-1106-preview", api_key=os.environ["OPENAI_API_KEY"]), enabled=True, @@ -275,7 +278,7 @@ def verify_structure_output(self, structure) -> dict: ], prompt_driver=AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], - model="gpt-4", + model="gpt-4o", azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], response_format="json_object", From c1e51476358554e6528e7605af5598c5217318ca Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 16 May 2024 06:59:49 -1000 Subject: [PATCH 5/9] Post merge fixes --- .../drivers/event-listener-drivers.md | 6 +-- griptape/config/base_structure_config.py | 2 + griptape/config/structure_config.py | 5 +++ .../config/structure_global_drivers_config.py | 45 ------------------- griptape/tasks/text_to_speech_task.py | 2 +- .../test_amazon_bedrock_structure_config.py | 1 + .../config/test_anthropic_structure_config.py | 1 + .../config/test_google_structure_config.py | 1 + .../config/test_openai_structure_config.py | 3 +- tests/unit/config/test_structure_config.py | 1 + 10 files changed, 16 insertions(+), 51 deletions(-) delete mode 100644 griptape/config/structure_global_drivers_config.py diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index da9a6c05d..f8c4080db 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -134,10 +134,8 @@ agent = Agent( ) ], config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=OpenAiChatPromptDriver( - model="gpt-3.5-turbo", temperature=0.7 - ), + prompt_driver=OpenAiChatPromptDriver( + model="gpt-3.5-turbo", temperature=0.7 ) ), event_listeners=[ diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index 4848eeda8..d716205c8 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -13,6 +13,7 @@ BaseImageQueryDriver, BasePromptDriver, BaseVectorStoreDriver, + BaseTextToSpeechDriver, ) from griptape.utils import dict_merge @@ -27,6 +28,7 @@ class BaseStructureConfig(BaseConfig, ABC): conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( default=None, kw_only=True, metadata={"serializable": True} ) + text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() diff --git a/griptape/config/structure_config.py b/griptape/config/structure_config.py index 363bc6034..63f1ea9f3 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/structure_config.py @@ -15,6 +15,8 @@ DummyPromptDriver, DummyImageQueryDriver, BaseImageQueryDriver, + BaseTextToSpeechDriver, + DummyTextToSpeechDriver, ) @@ -38,3 +40,6 @@ class StructureConfig(BaseStructureConfig): conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( default=None, kw_only=True, metadata={"serializable": True} ) + text_to_speech_driver: BaseTextToSpeechDriver = field( + default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True} + ) diff --git a/griptape/config/structure_global_drivers_config.py b/griptape/config/structure_global_drivers_config.py deleted file mode 100644 index b599039a2..000000000 --- a/griptape/config/structure_global_drivers_config.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Optional - -from attrs import Factory, define, field - -from griptape.drivers import ( - BaseConversationMemoryDriver, - BaseEmbeddingDriver, - BaseImageGenerationDriver, - BasePromptDriver, - BaseVectorStoreDriver, - DummyVectorStoreDriver, - DummyEmbeddingDriver, - DummyImageGenerationDriver, - DummyPromptDriver, - DummyImageQueryDriver, - BaseImageQueryDriver, - BaseTextToSpeechDriver, -) -from griptape.drivers.text_to_speech.dummy_text_to_speech_driver import DummyTextToSpeechDriver -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureGlobalDriversConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) - image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True} - ) - image_query_driver: BaseImageQueryDriver = field( - kw_only=True, default=Factory(lambda: DummyImageQueryDriver()), metadata={"serializable": True} - ) - embedding_driver: BaseEmbeddingDriver = field( - kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True} - ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda: DummyVectorStoreDriver()), kw_only=True, metadata={"serializable": True} - ) - conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( - default=None, kw_only=True, metadata={"serializable": True} - ) - text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True} - ) diff --git a/griptape/tasks/text_to_speech_task.py b/griptape/tasks/text_to_speech_task.py index ab90b1bbb..8a69227c5 100644 --- a/griptape/tasks/text_to_speech_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -37,7 +37,7 @@ def text_to_speech_engine(self) -> TextToSpeechEngine: if self._text_to_speech_engine is None: if self.structure is not None: self._text_to_speech_engine = TextToSpeechEngine( - text_to_speech_driver=self.structure.config.global_drivers.text_to_speech_driver + text_to_speech_driver=self.structure.config.text_to_speech_driver ) else: raise ValueError("Audio Generation Engine is not set.") diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index d787897cf..66ca44bb5 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -51,6 +51,7 @@ def test_to_dict(self, config): "type": "LocalVectorStoreDriver", }, "type": "AmazonBedrockStructureConfig", + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index 596365eb2..9f014092a 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -44,6 +44,7 @@ def test_to_dict(self, config): }, }, "conversation_memory_driver": None, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_structure_config.py index dcca9e29d..f089b611b 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_structure_config.py @@ -41,6 +41,7 @@ def test_to_dict(self, config): }, }, "conversation_memory_driver": None, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index a2df52216..bd8db27cd 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -17,7 +17,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, @@ -27,6 +27,7 @@ def test_to_dict(self, config): "user": "", }, "conversation_memory_driver": None, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index 5cc3e2561..9e1b00038 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -19,6 +19,7 @@ def test_to_dict(self, config): "embedding_driver": {"type": "DummyEmbeddingDriver"}, "type": "DummyVectorStoreDriver", }, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, } def test_from_dict(self, config): From fc427c77d9411696a617c8500e49f0b9a1c42cb7 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 17 May 2024 10:07:02 -0700 Subject: [PATCH 6/9] Add missing kw_only --- griptape/config/anthropic_structure_config.py | 10 ++++++++-- griptape/config/google_structure_config.py | 7 +++++-- griptape/config/openai_structure_config.py | 9 +++++++-- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/griptape/config/anthropic_structure_config.py b/griptape/config/anthropic_structure_config.py index 6b147b7fd..bfc37bfa9 100644 --- a/griptape/config/anthropic_structure_config.py +++ b/griptape/config/anthropic_structure_config.py @@ -12,18 +12,24 @@ @define class AnthropicStructureConfig(StructureConfig): prompt_driver: AnthropicPromptDriver = field( - default=Factory(lambda: AnthropicPromptDriver(model="claude-3-opus-20240229")), metadata={"serializable": True} + default=Factory(lambda: AnthropicPromptDriver(model="claude-3-opus-20240229")), + metadata={"serializable": True}, + kw_only=True, ) embedding_driver: VoyageAiEmbeddingDriver = field( - default=Factory(lambda: VoyageAiEmbeddingDriver(model="voyage-large-2")), metadata={"serializable": True} + default=Factory(lambda: VoyageAiEmbeddingDriver(model="voyage-large-2")), + metadata={"serializable": True}, + kw_only=True, ) vector_store_driver: LocalVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")) ), + kw_only=True, metadata={"serializable": True}, ) image_query_driver: AnthropicImageQueryDriver = field( default=Factory(lambda: AnthropicImageQueryDriver(model="claude-3-opus-20240229")), + kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/google_structure_config.py b/griptape/config/google_structure_config.py index 53f7bc35a..b83832b6a 100644 --- a/griptape/config/google_structure_config.py +++ b/griptape/config/google_structure_config.py @@ -7,14 +7,17 @@ @define class GoogleStructureConfig(StructureConfig): prompt_driver: GooglePromptDriver = field( - default=Factory(lambda: GooglePromptDriver(model="gemini-pro")), metadata={"serializable": True} + default=Factory(lambda: GooglePromptDriver(model="gemini-pro")), kw_only=True, metadata={"serializable": True} ) embedding_driver: GoogleEmbeddingDriver = field( - default=Factory(lambda: GoogleEmbeddingDriver(model="models/embedding-001")), metadata={"serializable": True} + default=Factory(lambda: GoogleEmbeddingDriver(model="models/embedding-001")), + kw_only=True, + metadata={"serializable": True}, ) vector_store_driver: LocalVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001")) ), + kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_structure_config.py index 0a06f6cf5..bc0a529cc 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_structure_config.py @@ -13,22 +13,27 @@ @define class OpenAiStructureConfig(StructureConfig): prompt_driver: OpenAiChatPromptDriver = field( - default=Factory(lambda: OpenAiChatPromptDriver(model="gpt-4o")), metadata={"serializable": True} + default=Factory(lambda: OpenAiChatPromptDriver(model="gpt-4o")), metadata={"serializable": True}, kw_only=True ) image_generation_driver: OpenAiImageGenerationDriver = field( default=Factory(lambda: OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512")), + kw_only=True, metadata={"serializable": True}, ) image_query_driver: OpenAiVisionImageQueryDriver = field( default=Factory(lambda: OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview")), + kw_only=True, metadata={"serializable": True}, ) embedding_driver: OpenAiEmbeddingDriver = field( - default=Factory(lambda: OpenAiEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True} + default=Factory(lambda: OpenAiEmbeddingDriver(model="text-embedding-3-small")), + metadata={"serializable": True}, + kw_only=True, ) vector_store_driver: LocalVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")) ), + kw_only=True, metadata={"serializable": True}, ) From fbaf17dec0707f788848fa14436e04123d633262 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 17 May 2024 12:42:44 -0700 Subject: [PATCH 7/9] Fix linter --- griptape/memory/task/storage/text_artifact_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 10912d6ae..8e4423f54 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from attr import define, field from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact from griptape.memory.task.storage import BaseArtifactStorage From 1130fb3271af2b6416b029125cd5469f5a7ef322 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 20 May 2024 08:19:35 -0700 Subject: [PATCH 8/9] Fix docs --- docs/griptape-framework/drivers/embedding-drivers.md | 5 +---- docs/griptape-framework/drivers/event-listener-drivers.md | 2 +- docs/griptape-framework/misc/events.md | 7 ++++++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 6fe2af60d..876013e7f 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -174,10 +174,7 @@ from griptape.drivers import ( OpenAiChatPromptDriver, VoyageAiEmbeddingDriver, ) -from griptape.config import ( - StructureGlobalDriversConfig, - StructureConfig, -) +from griptape.config import StructureConfig agent = Agent( tools=[WebScraper(), TaskMemoryClient(off_prompt=False)], diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index f8c4080db..d35489f98 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -118,7 +118,7 @@ The [AwsIotCoreEventListenerDriver](../../reference/griptape/drivers/event_liste ```python import os -from griptape.config import StructureConfig, StructureGlobalDriversConfig +from griptape.config import StructureConfig from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriver from griptape.events import ( EventListener, diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index eda790c99..e4ff8b8f3 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -130,15 +130,20 @@ from griptape.events import CompletionChunkEvent, EventListener from griptape.tasks import ToolkitTask from griptape.structures import Pipeline from griptape.tools import WebScraper, TaskMemoryClient +from griptape.config import OpenAiStructureConfig +from griptape.drivers import OpenAiChatPromptDriver pipeline = Pipeline( + config=OpenAiStructureConfig( + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True) + ), event_listeners=[ EventListener( lambda e: print(e.token, end="", flush=True), event_types=[CompletionChunkEvent], ) - ] + ], ) pipeline.add_tasks( From d31ec65ac654b43bba1d6dbeca5aab43c5d8437d Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 20 May 2024 08:25:38 -0700 Subject: [PATCH 9/9] Fix config type hints --- .../config/amazon_bedrock_structure_config.py | 18 +++++++++++------- griptape/config/anthropic_structure_config.py | 14 +++++++++----- griptape/config/google_structure_config.py | 17 ++++++++++++----- griptape/config/openai_structure_config.py | 17 +++++++++++------ 4 files changed, 43 insertions(+), 23 deletions(-) diff --git a/griptape/config/amazon_bedrock_structure_config.py b/griptape/config/amazon_bedrock_structure_config.py index 7caa55990..cefb97f57 100644 --- a/griptape/config/amazon_bedrock_structure_config.py +++ b/griptape/config/amazon_bedrock_structure_config.py @@ -1,4 +1,4 @@ -from attrs import define, field, Factory +from attrs import Factory, define, field from griptape.config import StructureConfig from griptape.drivers import ( @@ -6,8 +6,12 @@ AmazonBedrockImageQueryDriver, AmazonBedrockPromptDriver, AmazonBedrockTitanEmbeddingDriver, - BedrockClaudePromptModelDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BasePromptDriver, + BaseVectorStoreDriver, BedrockClaudeImageQueryModelDriver, + BedrockClaudePromptModelDriver, BedrockTitanImageGenerationModelDriver, LocalVectorStoreDriver, ) @@ -15,7 +19,7 @@ @define() class AmazonBedrockStructureConfig(StructureConfig): - prompt_driver: AmazonBedrockPromptDriver = field( + prompt_driver: BasePromptDriver = field( default=Factory( lambda: AmazonBedrockPromptDriver( model="anthropic.claude-3-sonnet-20240229-v1:0", @@ -25,7 +29,7 @@ class AmazonBedrockStructureConfig(StructureConfig): ), metadata={"serializable": True}, ) - image_generation_driver: AmazonBedrockImageGenerationDriver = field( + image_generation_driver: BaseImageGenerationDriver = field( default=Factory( lambda: AmazonBedrockImageGenerationDriver( model="amazon.titan-image-generator-v1", @@ -34,7 +38,7 @@ class AmazonBedrockStructureConfig(StructureConfig): ), metadata={"serializable": True}, ) - image_query_driver: AmazonBedrockImageQueryDriver = field( + image_query_driver: BaseImageGenerationDriver = field( default=Factory( lambda: AmazonBedrockImageQueryDriver( model="anthropic.claude-3-sonnet-20240229-v1:0", @@ -43,11 +47,11 @@ class AmazonBedrockStructureConfig(StructureConfig): ), metadata={"serializable": True}, ) - embedding_driver: AmazonBedrockTitanEmbeddingDriver = field( + embedding_driver: BaseEmbeddingDriver = field( default=Factory(lambda: AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1")), metadata={"serializable": True}, ) - vector_store_driver: LocalVectorStoreDriver = field( + vector_store_driver: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver( embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") diff --git a/griptape/config/anthropic_structure_config.py b/griptape/config/anthropic_structure_config.py index bfc37bfa9..8984300a5 100644 --- a/griptape/config/anthropic_structure_config.py +++ b/griptape/config/anthropic_structure_config.py @@ -1,9 +1,13 @@ -from attrs import define, field, Factory +from attrs import Factory, define, field from griptape.config import StructureConfig from griptape.drivers import ( AnthropicImageQueryDriver, AnthropicPromptDriver, + BaseEmbeddingDriver, + BaseImageQueryDriver, + BasePromptDriver, + BaseVectorStoreDriver, LocalVectorStoreDriver, VoyageAiEmbeddingDriver, ) @@ -11,24 +15,24 @@ @define class AnthropicStructureConfig(StructureConfig): - prompt_driver: AnthropicPromptDriver = field( + prompt_driver: BasePromptDriver = field( default=Factory(lambda: AnthropicPromptDriver(model="claude-3-opus-20240229")), metadata={"serializable": True}, kw_only=True, ) - embedding_driver: VoyageAiEmbeddingDriver = field( + embedding_driver: BaseEmbeddingDriver = field( default=Factory(lambda: VoyageAiEmbeddingDriver(model="voyage-large-2")), metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: LocalVectorStoreDriver = field( + vector_store_driver: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")) ), kw_only=True, metadata={"serializable": True}, ) - image_query_driver: AnthropicImageQueryDriver = field( + image_query_driver: BaseImageQueryDriver = field( default=Factory(lambda: AnthropicImageQueryDriver(model="claude-3-opus-20240229")), kw_only=True, metadata={"serializable": True}, diff --git a/griptape/config/google_structure_config.py b/griptape/config/google_structure_config.py index b83832b6a..744d08782 100644 --- a/griptape/config/google_structure_config.py +++ b/griptape/config/google_structure_config.py @@ -1,20 +1,27 @@ -from attrs import define, field, Factory +from attrs import Factory, define, field from griptape.config import StructureConfig -from griptape.drivers import GoogleEmbeddingDriver, GooglePromptDriver, LocalVectorStoreDriver +from griptape.drivers import ( + BaseEmbeddingDriver, + BasePromptDriver, + BaseVectorStoreDriver, + GoogleEmbeddingDriver, + GooglePromptDriver, + LocalVectorStoreDriver, +) @define class GoogleStructureConfig(StructureConfig): - prompt_driver: GooglePromptDriver = field( + prompt_driver: BasePromptDriver = field( default=Factory(lambda: GooglePromptDriver(model="gemini-pro")), kw_only=True, metadata={"serializable": True} ) - embedding_driver: GoogleEmbeddingDriver = field( + embedding_driver: BaseEmbeddingDriver = field( default=Factory(lambda: GoogleEmbeddingDriver(model="models/embedding-001")), kw_only=True, metadata={"serializable": True}, ) - vector_store_driver: LocalVectorStoreDriver = field( + vector_store_driver: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001")) ), diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_structure_config.py index bc0a529cc..5b2e163ba 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_structure_config.py @@ -1,7 +1,12 @@ -from attrs import define, Factory, field +from attrs import Factory, define, field from griptape.config import StructureConfig from griptape.drivers import ( + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BaseImageQueryDriver, + BasePromptDriver, + BaseVectorStoreDriver, LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver, @@ -12,25 +17,25 @@ @define class OpenAiStructureConfig(StructureConfig): - prompt_driver: OpenAiChatPromptDriver = field( + prompt_driver: BasePromptDriver = field( default=Factory(lambda: OpenAiChatPromptDriver(model="gpt-4o")), metadata={"serializable": True}, kw_only=True ) - image_generation_driver: OpenAiImageGenerationDriver = field( + image_generation_driver: BaseImageGenerationDriver = field( default=Factory(lambda: OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512")), kw_only=True, metadata={"serializable": True}, ) - image_query_driver: OpenAiVisionImageQueryDriver = field( + image_query_driver: BaseImageQueryDriver = field( default=Factory(lambda: OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview")), kw_only=True, metadata={"serializable": True}, ) - embedding_driver: OpenAiEmbeddingDriver = field( + embedding_driver: BaseEmbeddingDriver = field( default=Factory(lambda: OpenAiEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: LocalVectorStoreDriver = field( + vector_store_driver: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")) ),