diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index da9a6c05d0..f8c4080dbd 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -134,10 +134,8 @@ agent = Agent( ) ], config=StructureConfig( - global_drivers=StructureGlobalDriversConfig( - prompt_driver=OpenAiChatPromptDriver( - model="gpt-3.5-turbo", temperature=0.7 - ), + prompt_driver=OpenAiChatPromptDriver( + model="gpt-3.5-turbo", temperature=0.7 ) ), event_listeners=[ diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index 4848eeda8c..d716205c81 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -13,6 +13,7 @@ BaseImageQueryDriver, BasePromptDriver, BaseVectorStoreDriver, + BaseTextToSpeechDriver, ) from griptape.utils import dict_merge @@ -27,6 +28,7 @@ class BaseStructureConfig(BaseConfig, ABC): conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( default=None, kw_only=True, metadata={"serializable": True} ) + text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() diff --git a/griptape/config/structure_config.py b/griptape/config/structure_config.py index 363bc60343..63f1ea9f38 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/structure_config.py @@ -15,6 +15,8 @@ DummyPromptDriver, DummyImageQueryDriver, BaseImageQueryDriver, + BaseTextToSpeechDriver, + DummyTextToSpeechDriver, ) @@ -38,3 +40,6 @@ class StructureConfig(BaseStructureConfig): conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( default=None, kw_only=True, metadata={"serializable": True} ) + text_to_speech_driver: BaseTextToSpeechDriver = field( + default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True} + ) diff --git a/griptape/config/structure_global_drivers_config.py b/griptape/config/structure_global_drivers_config.py deleted file mode 100644 index b599039a26..0000000000 --- a/griptape/config/structure_global_drivers_config.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Optional - -from attrs import Factory, define, field - -from griptape.drivers import ( - BaseConversationMemoryDriver, - BaseEmbeddingDriver, - BaseImageGenerationDriver, - BasePromptDriver, - BaseVectorStoreDriver, - DummyVectorStoreDriver, - DummyEmbeddingDriver, - DummyImageGenerationDriver, - DummyPromptDriver, - DummyImageQueryDriver, - BaseImageQueryDriver, - BaseTextToSpeechDriver, -) -from griptape.drivers.text_to_speech.dummy_text_to_speech_driver import DummyTextToSpeechDriver -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureGlobalDriversConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) - image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True} - ) - image_query_driver: BaseImageQueryDriver = field( - kw_only=True, default=Factory(lambda: DummyImageQueryDriver()), metadata={"serializable": True} - ) - embedding_driver: BaseEmbeddingDriver = field( - kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True} - ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda: DummyVectorStoreDriver()), kw_only=True, metadata={"serializable": True} - ) - conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( - default=None, kw_only=True, metadata={"serializable": True} - ) - text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True} - ) diff --git a/griptape/tasks/text_to_speech_task.py b/griptape/tasks/text_to_speech_task.py index ab90b1bbba..8a69227c5d 100644 --- a/griptape/tasks/text_to_speech_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -37,7 +37,7 @@ def text_to_speech_engine(self) -> TextToSpeechEngine: if self._text_to_speech_engine is None: if self.structure is not None: self._text_to_speech_engine = TextToSpeechEngine( - text_to_speech_driver=self.structure.config.global_drivers.text_to_speech_driver + text_to_speech_driver=self.structure.config.text_to_speech_driver ) else: raise ValueError("Audio Generation Engine is not set.") diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index d787897cfb..66ca44bb53 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -51,6 +51,7 @@ def test_to_dict(self, config): "type": "LocalVectorStoreDriver", }, "type": "AmazonBedrockStructureConfig", + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index 596365eb20..9f014092a1 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -44,6 +44,7 @@ def test_to_dict(self, config): }, }, "conversation_memory_driver": None, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_structure_config.py index dcca9e29dc..f089b611b3 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_structure_config.py @@ -41,6 +41,7 @@ def test_to_dict(self, config): }, }, "conversation_memory_driver": None, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index a2df522160..bd8db27cd7 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -17,7 +17,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, - "model": "gpt-4", + "model": "gpt-4o", "organization": None, "response_format": None, "seed": None, @@ -27,6 +27,7 @@ def test_to_dict(self, config): "user": "", }, "conversation_memory_driver": None, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index 5cc3e2561d..9e1b000389 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -19,6 +19,7 @@ def test_to_dict(self, config): "embedding_driver": {"type": "DummyEmbeddingDriver"}, "type": "DummyVectorStoreDriver", }, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, } def test_from_dict(self, config):