diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index bd43f24048..0e301bc601 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -164,7 +164,7 @@ from griptape.tools import WebScraper pipeline = Pipeline() -pipeline.config.global_drivers.prompt_driver.stream = True +pipeline.config.prompt_driver.stream = True pipeline.add_tasks(ToolkitTask("Based on https://griptape.ai, tell me what griptape is.", tools=[WebScraper()])) for artifact in Stream(pipeline).run(): diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index c392200f9b..b54c426d82 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -122,7 +122,7 @@ from griptape.config import AmazonBedrockStructureConfig from griptape.drivers import AmazonBedrockCohereEmbeddingDriver custom_config = AmazonBedrockStructureConfig() -custom_config.global_drivers.embedding_driver = AmazonBedrockCohereEmbeddingDriver() +custom_config.embedding_driver = AmazonBedrockCohereEmbeddingDriver() custom_config.merge_config( { "task_memory": { diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index 273a338f20..8495560178 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -1,12 +1,5 @@ from .base_config import BaseConfig -from .structure_global_drivers_config import StructureGlobalDriversConfig -from .structure_task_memory_extraction_engine_csv_config import StructureTaskMemoryExtractionEngineCsvConfig -from .structure_task_memory_extraction_engine_json_config import StructureTaskMemoryExtractionEngineJsonConfig -from .structure_task_memory_extraction_engine_config import StructureTaskMemoryExtractionEngineConfig -from .structure_task_memory_query_engine_config import StructureTaskMemoryQueryEngineConfig -from .structure_task_memory_summary_engine_config import StructureTaskMemorySummaryEngineConfig -from .structure_task_memory_config import StructureTaskMemoryConfig from .base_structure_config import BaseStructureConfig from .structure_config import StructureConfig @@ -19,13 +12,6 @@ __all__ = [ "BaseConfig", "BaseStructureConfig", - "StructureTaskMemoryConfig", - "StructureGlobalDriversConfig", - "StructureTaskMemoryQueryEngineConfig", - "StructureTaskMemorySummaryEngineConfig", - "StructureTaskMemoryExtractionEngineConfig", - "StructureTaskMemoryExtractionEngineCsvConfig", - "StructureTaskMemoryExtractionEngineJsonConfig", "StructureConfig", "OpenAiStructureConfig", "AmazonBedrockStructureConfig", diff --git a/griptape/config/amazon_bedrock_structure_config.py b/griptape/config/amazon_bedrock_structure_config.py index 54b8d91c7e..ff8d6b589a 100644 --- a/griptape/config/amazon_bedrock_structure_config.py +++ b/griptape/config/amazon_bedrock_structure_config.py @@ -1,15 +1,6 @@ -from attrs import Factory, define, field +from attrs import define -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryExtractionEngineCsvConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, -) +from griptape.config import StructureConfig from griptape.drivers import ( AmazonBedrockImageGenerationDriver, AmazonBedrockImageQueryDriver, @@ -23,47 +14,19 @@ @define() -class AmazonBedrockStructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( - default=Factory( - lambda: StructureGlobalDriversConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", - stream=False, - prompt_model_driver=BedrockClaudePromptModelDriver(), - ), - image_generation_driver=AmazonBedrockImageGenerationDriver( - model="amazon.titan-image-generator-v1", - image_generation_model_driver=BedrockTitanImageGenerationModelDriver(), - ), - image_query_driver=AmazonBedrockImageQueryDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", - image_query_model_driver=BedrockClaudeImageQueryModelDriver(), - ), - embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1"), - vector_store_driver=LocalVectorStoreDriver( - embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") - ), - ) - ), - kw_only=True, - metadata={"serializable": True}, +class AmazonBedrockStructureConfig(StructureConfig): + prompt_driver = AmazonBedrockPromptDriver( + model="anthropic.claude-3-sonnet-20240229-v1:0", + stream=False, + prompt_model_driver=BedrockClaudePromptModelDriver(), + ) + image_generation_driver = AmazonBedrockImageGenerationDriver( + model="amazon.titan-image-generator-v1", image_generation_model_driver=BedrockTitanImageGenerationModelDriver() + ) + image_query_driver = AmazonBedrockImageQueryDriver( + model="anthropic.claude-3-sonnet-20240229-v1:0", image_query_model_driver=BedrockClaudeImageQueryModelDriver() ) - task_memory: StructureTaskMemoryConfig = field( - default=Factory( - lambda self: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=self.global_drivers.prompt_driver, - vector_store_driver=self.global_drivers.vector_store_driver, - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), - json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - takes_self=True, - ), - kw_only=True, - metadata={"serializable": True}, + embedding_driver = AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") + vector_store_driver = LocalVectorStoreDriver( + embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") ) diff --git a/griptape/config/anthropic_structure_config.py b/griptape/config/anthropic_structure_config.py index 06978a5c2e..2eb67042d1 100644 --- a/griptape/config/anthropic_structure_config.py +++ b/griptape/config/anthropic_structure_config.py @@ -1,54 +1,17 @@ -from attrs import Factory, define, field +from attrs import define -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryExtractionEngineCsvConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, -) +from griptape.config import StructureConfig from griptape.drivers import ( - LocalVectorStoreDriver, - AnthropicPromptDriver, AnthropicImageQueryDriver, + AnthropicPromptDriver, + LocalVectorStoreDriver, VoyageAiEmbeddingDriver, ) @define -class AnthropicStructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( - default=Factory( - lambda: StructureGlobalDriversConfig( - prompt_driver=AnthropicPromptDriver(model="claude-3-opus-20240229"), - embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2"), - vector_store_driver=LocalVectorStoreDriver( - embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2") - ), - image_query_driver=AnthropicImageQueryDriver(model="claude-3-opus-20240229"), - ) - ), - kw_only=True, - metadata={"serializable": True}, - ) - task_memory: StructureTaskMemoryConfig = field( - default=Factory( - lambda self: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=self.global_drivers.prompt_driver, - vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver), - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), - json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - takes_self=True, - ), - kw_only=True, - metadata={"serializable": True}, - ) +class AnthropicStructureConfig(StructureConfig): + prompt_driver = AnthropicPromptDriver(model="claude-3-opus-20240229") + embedding_driver = VoyageAiEmbeddingDriver(model="voyage-large-2") + vector_store_driver = LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")) + image_query_driver = AnthropicImageQueryDriver(model="claude-3-opus-20240229") diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index afa8e40120..4848eeda8c 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -1,17 +1,32 @@ from __future__ import annotations from abc import ABC +from typing import Optional from attr import define, field -from griptape.config import BaseConfig, StructureGlobalDriversConfig, StructureTaskMemoryConfig +from griptape.config import BaseConfig +from griptape.drivers import ( + BaseConversationMemoryDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BaseImageQueryDriver, + BasePromptDriver, + BaseVectorStoreDriver, +) from griptape.utils import dict_merge @define class BaseStructureConfig(BaseConfig, ABC): - global_drivers: StructureGlobalDriversConfig = field(kw_only=True, metadata={"serializable": True}) - task_memory: StructureTaskMemoryConfig = field(kw_only=True, metadata={"serializable": True}) + prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) + image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) + image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) + embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True}) + vector_store_driver: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True}) + conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + default=None, kw_only=True, metadata={"serializable": True} + ) def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() diff --git a/griptape/config/google_structure_config.py b/griptape/config/google_structure_config.py index 9ba40622ff..22d2994080 100644 --- a/griptape/config/google_structure_config.py +++ b/griptape/config/google_structure_config.py @@ -1,48 +1,11 @@ -from attrs import Factory, define, field +from attrs import define -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryExtractionEngineCsvConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, -) -from griptape.drivers import LocalVectorStoreDriver, GooglePromptDriver, GoogleEmbeddingDriver +from griptape.config import StructureConfig +from griptape.drivers import GoogleEmbeddingDriver, GooglePromptDriver, LocalVectorStoreDriver @define -class GoogleStructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( - default=Factory( - lambda: StructureGlobalDriversConfig( - prompt_driver=GooglePromptDriver(model="gemini-pro"), - embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001"), - vector_store_driver=LocalVectorStoreDriver( - embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001") - ), - ) - ), - kw_only=True, - metadata={"serializable": True}, - ) - task_memory: StructureTaskMemoryConfig = field( - default=Factory( - lambda self: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=self.global_drivers.prompt_driver, - vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver), - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), - json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - takes_self=True, - ), - kw_only=True, - metadata={"serializable": True}, - ) +class GoogleStructureConfig(StructureConfig): + prompt_driver = GooglePromptDriver(model="gemini-pro") + embedding_driver = GoogleEmbeddingDriver(model="models/embedding-001") + vector_store_driver = LocalVectorStoreDriver(embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001")) diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_structure_config.py index 283fca2d1c..71143c61e5 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_structure_config.py @@ -1,15 +1,6 @@ -from attrs import Factory, define, field +from attrs import define -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryExtractionEngineCsvConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, -) +from griptape.config import StructureConfig from griptape.drivers import ( LocalVectorStoreDriver, OpenAiChatPromptDriver, @@ -20,37 +11,9 @@ @define -class OpenAiStructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( - default=Factory( - lambda: StructureGlobalDriversConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), - image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"), - image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview"), - embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"), - vector_store_driver=LocalVectorStoreDriver( - embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small") - ), - ) - ), - kw_only=True, - metadata={"serializable": True}, - ) - task_memory: StructureTaskMemoryConfig = field( - default=Factory( - lambda self: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=self.global_drivers.prompt_driver, - vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver), - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), - json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - takes_self=True, - ), - kw_only=True, - metadata={"serializable": True}, - ) +class OpenAiStructureConfig(StructureConfig): + prompt_driver = OpenAiChatPromptDriver(model="gpt-4") + image_generation_driver = OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512") + image_query_driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview") + embedding_driver = OpenAiEmbeddingDriver(model="text-embedding-3-small") + vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")) diff --git a/griptape/config/structure_config.py b/griptape/config/structure_config.py index 6381450f42..363bc60343 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/structure_config.py @@ -1,38 +1,40 @@ from attrs import Factory, define, field +from typing import Optional -from griptape.config import ( - BaseStructureConfig, - StructureGlobalDriversConfig, - StructureTaskMemoryConfig, - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryExtractionEngineCsvConfig, - StructureTaskMemoryExtractionEngineJsonConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, +from griptape.config import BaseStructureConfig + +from griptape.drivers import ( + BaseConversationMemoryDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BasePromptDriver, + BaseVectorStoreDriver, + DummyVectorStoreDriver, + DummyEmbeddingDriver, + DummyImageGenerationDriver, + DummyPromptDriver, + DummyImageQueryDriver, + BaseImageQueryDriver, ) -from griptape.drivers import LocalVectorStoreDriver @define class StructureConfig(BaseStructureConfig): - global_drivers: StructureGlobalDriversConfig = field( - default=Factory(lambda: StructureGlobalDriversConfig()), kw_only=True, metadata={"serializable": True} + prompt_driver: BasePromptDriver = field( + kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} + ) + image_generation_driver: BaseImageGenerationDriver = field( + kw_only=True, default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True} + ) + image_query_driver: BaseImageQueryDriver = field( + kw_only=True, default=Factory(lambda: DummyImageQueryDriver()), metadata={"serializable": True} + ) + embedding_driver: BaseEmbeddingDriver = field( + kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True} + ) + vector_store_driver: BaseVectorStoreDriver = field( + default=Factory(lambda: DummyVectorStoreDriver()), kw_only=True, metadata={"serializable": True} ) - task_memory: StructureTaskMemoryConfig = field( - default=Factory( - lambda self: StructureTaskMemoryConfig( - query_engine=StructureTaskMemoryQueryEngineConfig( - prompt_driver=self.global_drivers.prompt_driver, - vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver), - ), - extraction_engine=StructureTaskMemoryExtractionEngineConfig( - csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), - json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), - ), - takes_self=True, - ), - kw_only=True, - metadata={"serializable": True}, + conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + default=None, kw_only=True, metadata={"serializable": True} ) diff --git a/griptape/config/structure_global_drivers_config.py b/griptape/config/structure_global_drivers_config.py deleted file mode 100644 index b3e2e879fd..0000000000 --- a/griptape/config/structure_global_drivers_config.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Optional - -from attrs import Factory, define, field - -from griptape.drivers import ( - BaseConversationMemoryDriver, - BaseEmbeddingDriver, - BaseImageGenerationDriver, - BasePromptDriver, - BaseVectorStoreDriver, - DummyVectorStoreDriver, - DummyEmbeddingDriver, - DummyImageGenerationDriver, - DummyPromptDriver, - DummyImageQueryDriver, - BaseImageQueryDriver, -) -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureGlobalDriversConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) - image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True} - ) - image_query_driver: BaseImageQueryDriver = field( - kw_only=True, default=Factory(lambda: DummyImageQueryDriver()), metadata={"serializable": True} - ) - embedding_driver: BaseEmbeddingDriver = field( - kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True} - ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda: DummyVectorStoreDriver()), kw_only=True, metadata={"serializable": True} - ) - conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( - default=None, kw_only=True, metadata={"serializable": True} - ) diff --git a/griptape/config/structure_task_memory_config.py b/griptape/config/structure_task_memory_config.py deleted file mode 100644 index 3b8648dcff..0000000000 --- a/griptape/config/structure_task_memory_config.py +++ /dev/null @@ -1,23 +0,0 @@ -from attrs import Factory, define, field - -from griptape.config import ( - StructureTaskMemoryExtractionEngineConfig, - StructureTaskMemoryQueryEngineConfig, - StructureTaskMemorySummaryEngineConfig, -) -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemoryConfig(SerializableMixin): - query_engine: StructureTaskMemoryQueryEngineConfig = field( - kw_only=True, default=Factory(lambda: StructureTaskMemoryQueryEngineConfig()), metadata={"serializable": True} - ) - extraction_engine: StructureTaskMemoryExtractionEngineConfig = field( - kw_only=True, - default=Factory(lambda: StructureTaskMemoryExtractionEngineConfig()), - metadata={"serializable": True}, - ) - summary_engine: StructureTaskMemorySummaryEngineConfig = field( - kw_only=True, default=Factory(lambda: StructureTaskMemorySummaryEngineConfig()), metadata={"serializable": True} - ) diff --git a/griptape/config/structure_task_memory_extraction_engine_config.py b/griptape/config/structure_task_memory_extraction_engine_config.py deleted file mode 100644 index 8c2a58f02a..0000000000 --- a/griptape/config/structure_task_memory_extraction_engine_config.py +++ /dev/null @@ -1,18 +0,0 @@ -from attrs import Factory, define, field - -from griptape.config import StructureTaskMemoryExtractionEngineCsvConfig, StructureTaskMemoryExtractionEngineJsonConfig -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemoryExtractionEngineConfig(SerializableMixin): - csv: StructureTaskMemoryExtractionEngineCsvConfig = field( - kw_only=True, - default=Factory(lambda: StructureTaskMemoryExtractionEngineCsvConfig()), - metadata={"serializable": True}, - ) - json: StructureTaskMemoryExtractionEngineJsonConfig = field( - kw_only=True, - default=Factory(lambda: StructureTaskMemoryExtractionEngineJsonConfig()), - metadata={"serializable": True}, - ) diff --git a/griptape/config/structure_task_memory_extraction_engine_csv_config.py b/griptape/config/structure_task_memory_extraction_engine_csv_config.py deleted file mode 100644 index cce5f3e029..0000000000 --- a/griptape/config/structure_task_memory_extraction_engine_csv_config.py +++ /dev/null @@ -1,11 +0,0 @@ -from attrs import define, field, Factory - -from griptape.drivers import BasePromptDriver, DummyPromptDriver -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemoryExtractionEngineCsvConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) diff --git a/griptape/config/structure_task_memory_extraction_engine_json_config.py b/griptape/config/structure_task_memory_extraction_engine_json_config.py deleted file mode 100644 index 04210b8c86..0000000000 --- a/griptape/config/structure_task_memory_extraction_engine_json_config.py +++ /dev/null @@ -1,11 +0,0 @@ -from attrs import define, field, Factory - -from griptape.drivers import BasePromptDriver, DummyPromptDriver -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemoryExtractionEngineJsonConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) diff --git a/griptape/config/structure_task_memory_query_engine_config.py b/griptape/config/structure_task_memory_query_engine_config.py deleted file mode 100644 index 30d6bbcf94..0000000000 --- a/griptape/config/structure_task_memory_query_engine_config.py +++ /dev/null @@ -1,22 +0,0 @@ -from attrs import Factory, define, field - -from griptape.drivers import ( - BasePromptDriver, - BaseVectorStoreDriver, - DummyVectorStoreDriver, - DummyEmbeddingDriver, - DummyPromptDriver, -) -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemoryQueryEngineConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) - vector_store_driver: BaseVectorStoreDriver = field( - kw_only=True, - default=Factory(lambda: DummyVectorStoreDriver(embedding_driver=DummyEmbeddingDriver())), - metadata={"serializable": True}, - ) diff --git a/griptape/config/structure_task_memory_summary_engine_config.py b/griptape/config/structure_task_memory_summary_engine_config.py deleted file mode 100644 index 100f9d8f1c..0000000000 --- a/griptape/config/structure_task_memory_summary_engine_config.py +++ /dev/null @@ -1,11 +0,0 @@ -from attrs import Factory, define, field - -from griptape.drivers import BasePromptDriver, DummyPromptDriver -from griptape.mixins.serializable_mixin import SerializableMixin - - -@define -class StructureTaskMemorySummaryEngineConfig(SerializableMixin): - prompt_driver: BasePromptDriver = field( - kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True} - ) diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 1f622b0d2f..fa06a3c764 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -25,7 +25,7 @@ class SummaryConversationMemory(ConversationMemory): def prompt_driver(self) -> BasePromptDriver: if self._prompt_driver is None: if self.structure is not None: - self._prompt_driver = self.structure.config.global_drivers.prompt_driver + self._prompt_driver = self.structure.config.prompt_driver else: raise ValueError("Prompt Driver is not set.") return self._prompt_driver diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 104bb6f5ed..eec3c690ef 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -11,9 +11,9 @@ @define class TextArtifactStorage(BaseArtifactStorage): query_engine: VectorQueryEngine = field(kw_only=True) - summary_engine: BaseSummaryEngine = field(kw_only=True) - csv_extraction_engine: CsvExtractionEngine = field(kw_only=True) - json_extraction_engine: JsonExtractionEngine = field(kw_only=True) + summary_engine: Optional[BaseSummaryEngine] = field(kw_only=True, default=None) + csv_extraction_engine: Optional[CsvExtractionEngine] = field(kw_only=True, default=None) + json_extraction_engine: Optional[JsonExtractionEngine] = field(kw_only=True, default=None) def can_store(self, artifact: BaseArtifact) -> bool: return isinstance(artifact, TextArtifact) @@ -28,6 +28,8 @@ def load_artifacts(self, namespace: str) -> ListArtifact: return self.query_engine.load_artifacts(namespace) def summarize(self, namespace: str) -> TextArtifact: + if self.summary_engine is None: + raise ValueError("Summary engine is not set.") return self.summary_engine.summarize_artifacts(self.load_artifacts(namespace)) def query(self, namespace: str, query: str, metadata: Any = None) -> TextArtifact: diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index aee95dbf89..cc123f8f80 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -1,10 +1,12 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional -from attr import define, field +from attr import define, field, Factory from griptape.tools import BaseTool from griptape.memory.structure import Run from griptape.structures import Structure from griptape.tasks import PromptTask, ToolkitTask +from griptape.drivers import BasePromptDriver, BaseEmbeddingDriver +from griptape.config import BaseStructureConfig if TYPE_CHECKING: from griptape.tasks import BaseTask @@ -15,6 +17,12 @@ class Agent(Structure): input_template: str = field(default=PromptTask.DEFAULT_INPUT_TEMPLATE) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) + stream: Optional[bool] = field(default=None, kw_only=True) + prompt_driver: Optional[BasePromptDriver] = field(default=None) + embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) + config: BaseStructureConfig = field( + default=Factory(lambda self: self.default_config, takes_self=True), kw_only=True + ) def __attrs_post_init__(self) -> None: super().__attrs_post_init__() @@ -32,6 +40,18 @@ def __attrs_post_init__(self) -> None: def task(self) -> BaseTask: return self.tasks[0] + @prompt_driver.validator # pyright: ignore + def validate_prompt_driver(self, attribute, value): + pass + + @embedding_driver.validator # pyright: ignore + def validate_embedding_driver(self, attribute, value): + pass + + @stream.validator # pyright: ignore + def validate_stream(self, attribute, value): + pass + def add_task(self, task: BaseTask) -> BaseTask: self.tasks.clear() diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index e807393f4e..a55d478337 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -56,8 +56,7 @@ class Structure(ABC): event_listeners: list[EventListener] = field(factory=list, kw_only=True) conversation_memory: Optional[BaseConversationMemory] = field( default=Factory( - lambda self: ConversationMemory(driver=self.config.global_drivers.conversation_memory_driver), - takes_self=True, + lambda self: ConversationMemory(driver=self.config.conversation_memory_driver), takes_self=True ), kw_only=True, ) @@ -98,19 +97,17 @@ def __add__(self, other: BaseTask | list[BaseTask]) -> list[BaseTask]: @prompt_driver.validator # pyright: ignore def validate_prompt_driver(self, attribute, value): if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.global_drivers.prompt_driver` instead.") + deprecation_warn(f"`{attribute.name}` is deprecated, use `config.prompt_driver` instead.") @embedding_driver.validator # pyright: ignore def validate_embedding_driver(self, attribute, value): if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.global_drivers.embedding_driver` instead.") + deprecation_warn(f"`{attribute.name}` is deprecated, use `config.embedding_driver` instead.") @stream.validator # pyright: ignore def validate_stream(self, attribute, value): if value is not None: - deprecation_warn( - f"`{attribute.name}` is deprecated, use `config.global_drivers.prompt_driver.stream` instead." - ) + deprecation_warn(f"`{attribute.name}` is deprecated, use `config.prompt_driver.stream` instead.") @property def execution_args(self) -> tuple: @@ -162,15 +159,9 @@ def default_config(self) -> BaseStructureConfig: vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) - config.global_drivers.prompt_driver = prompt_driver - config.global_drivers.vector_store_driver = vector_store_driver - config.global_drivers.embedding_driver = embedding_driver - - config.task_memory.query_engine.prompt_driver = prompt_driver - config.task_memory.query_engine.vector_store_driver = vector_store_driver - config.task_memory.summary_engine.prompt_driver = prompt_driver - config.task_memory.extraction_engine.csv.prompt_driver = prompt_driver - config.task_memory.extraction_engine.json.prompt_driver = prompt_driver + config.prompt_driver = prompt_driver + config.vector_store_driver = vector_store_driver + config.embedding_driver = embedding_driver else: config = OpenAiStructureConfig() @@ -178,45 +169,15 @@ def default_config(self) -> BaseStructureConfig: @property def default_task_memory(self) -> TaskMemory: - global_drivers = self.config.global_drivers - task_memory = self.config.task_memory - return TaskMemory( artifact_storages={ TextArtifact: TextArtifactStorage( query_engine=VectorQueryEngine( - prompt_driver=( - global_drivers.prompt_driver - if isinstance(task_memory.query_engine.prompt_driver, DummyPromptDriver) - else task_memory.query_engine.prompt_driver - ), - vector_store_driver=( - global_drivers.vector_store_driver - if isinstance(task_memory.query_engine.prompt_driver, DummyVectorStoreDriver) - else task_memory.query_engine.vector_store_driver - ), - ), - summary_engine=PromptSummaryEngine( - prompt_driver=( - global_drivers.prompt_driver - if isinstance(task_memory.summary_engine.prompt_driver, DummyPromptDriver) - else task_memory.summary_engine.prompt_driver - ) - ), - csv_extraction_engine=CsvExtractionEngine( - prompt_driver=( - global_drivers.prompt_driver - if isinstance(task_memory.extraction_engine.csv.prompt_driver, DummyPromptDriver) - else task_memory.extraction_engine.csv.prompt_driver - ) - ), - json_extraction_engine=JsonExtractionEngine( - prompt_driver=( - global_drivers.prompt_driver - if isinstance(task_memory.extraction_engine.json.prompt_driver, DummyPromptDriver) - else task_memory.extraction_engine.json.prompt_driver - ) + prompt_driver=self.config.prompt_driver, vector_store_driver=self.config.vector_store_driver ), + summary_engine=PromptSummaryEngine(prompt_driver=self.config.prompt_driver), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=self.config.prompt_driver), + json_extraction_engine=JsonExtractionEngine(prompt_driver=self.config.prompt_driver), ), BlobArtifact: BlobArtifactStorage(), } diff --git a/griptape/tasks/csv_extraction_task.py b/griptape/tasks/csv_extraction_task.py index ae71ea5d4f..2f5f3db567 100644 --- a/griptape/tasks/csv_extraction_task.py +++ b/griptape/tasks/csv_extraction_task.py @@ -12,9 +12,7 @@ class CsvExtractionTask(ExtractionTask): def extraction_engine(self) -> CsvExtractionEngine: if self._extraction_engine is None: if self.structure is not None: - self._extraction_engine = CsvExtractionEngine( - prompt_driver=self.structure.config.global_drivers.prompt_driver - ) + self._extraction_engine = CsvExtractionEngine(prompt_driver=self.structure.config.prompt_driver) else: raise ValueError("Extraction Engine is not set.") return self._extraction_engine diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py index 5a8c49f9cd..faa037eebb 100644 --- a/griptape/tasks/image_query_task.py +++ b/griptape/tasks/image_query_task.py @@ -56,9 +56,7 @@ def input( def image_query_engine(self) -> ImageQueryEngine: if self._image_query_engine is None: if self.structure is not None: - self._image_query_engine = ImageQueryEngine( - image_query_driver=self.structure.config.global_drivers.image_query_driver - ) + self._image_query_engine = ImageQueryEngine(image_query_driver=self.structure.config.image_query_driver) else: raise ValueError("Image Query Engine is not set.") return self._image_query_engine diff --git a/griptape/tasks/inpainting_image_generation_task.py b/griptape/tasks/inpainting_image_generation_task.py index 028ae336e1..f3b2edb7a9 100644 --- a/griptape/tasks/inpainting_image_generation_task.py +++ b/griptape/tasks/inpainting_image_generation_task.py @@ -57,7 +57,7 @@ def image_generation_engine(self) -> InpaintingImageGenerationEngine: if self._image_generation_engine is None: if self.structure is not None: self._image_generation_engine = InpaintingImageGenerationEngine( - image_generation_driver=self.structure.config.global_drivers.image_generation_driver + image_generation_driver=self.structure.config.image_generation_driver ) else: raise ValueError("Image Generation Engine is not set.") diff --git a/griptape/tasks/json_extraction_task.py b/griptape/tasks/json_extraction_task.py index a43b1e1e2f..e1f082fd82 100644 --- a/griptape/tasks/json_extraction_task.py +++ b/griptape/tasks/json_extraction_task.py @@ -12,9 +12,7 @@ class JsonExtractionTask(ExtractionTask): def extraction_engine(self) -> JsonExtractionEngine: if self._extraction_engine is None: if self.structure is not None: - self._extraction_engine = JsonExtractionEngine( - prompt_driver=self.structure.config.global_drivers.prompt_driver - ) + self._extraction_engine = JsonExtractionEngine(prompt_driver=self.structure.config.prompt_driver) else: raise ValueError("Extraction Engine is not set.") return self._extraction_engine diff --git a/griptape/tasks/outpainting_image_generation_task.py b/griptape/tasks/outpainting_image_generation_task.py index 203a6c2ba2..575f19f6bb 100644 --- a/griptape/tasks/outpainting_image_generation_task.py +++ b/griptape/tasks/outpainting_image_generation_task.py @@ -56,7 +56,7 @@ def image_generation_engine(self) -> OutpaintingImageGenerationEngine: if self._image_generation_engine is None: if self.structure is not None: self._image_generation_engine = OutpaintingImageGenerationEngine( - image_generation_driver=self.structure.config.global_drivers.image_generation_driver + image_generation_driver=self.structure.config.image_generation_driver ) else: raise ValueError("Image Generation Engine is not set.") diff --git a/griptape/tasks/prompt_image_generation_task.py b/griptape/tasks/prompt_image_generation_task.py index 7d9bad53eb..83c5eb5e28 100644 --- a/griptape/tasks/prompt_image_generation_task.py +++ b/griptape/tasks/prompt_image_generation_task.py @@ -50,7 +50,7 @@ def image_generation_engine(self) -> PromptImageGenerationEngine: if self._image_generation_engine is None: if self.structure is not None: self._image_generation_engine = PromptImageGenerationEngine( - image_generation_driver=self.structure.config.global_drivers.image_generation_driver + image_generation_driver=self.structure.config.image_generation_driver ) else: raise ValueError("Image Generation Engine is not set.") diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 525e21eefa..16f7c6dac6 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -42,7 +42,7 @@ def prompt_stack(self) -> PromptStack: def prompt_driver(self) -> BasePromptDriver: if self._prompt_driver is None: if self.structure is not None: - self._prompt_driver = self.structure.config.global_drivers.prompt_driver + self._prompt_driver = self.structure.config.prompt_driver else: raise ValueError("Prompt Driver is not set") return self._prompt_driver diff --git a/griptape/tasks/text_query_task.py b/griptape/tasks/text_query_task.py index 88fb94d106..5fee68103e 100644 --- a/griptape/tasks/text_query_task.py +++ b/griptape/tasks/text_query_task.py @@ -18,8 +18,8 @@ def query_engine(self) -> BaseQueryEngine: if self._query_engine is None: if self.structure is not None: self._query_engine = VectorQueryEngine( - prompt_driver=self.structure.config.global_drivers.prompt_driver, - vector_store_driver=self.structure.config.global_drivers.vector_store_driver, + prompt_driver=self.structure.config.prompt_driver, + vector_store_driver=self.structure.config.vector_store_driver, ) else: raise ValueError("Query Engine is not set.") diff --git a/griptape/tasks/text_summary_task.py b/griptape/tasks/text_summary_task.py index 7cb6a8a0b5..f10f851d00 100644 --- a/griptape/tasks/text_summary_task.py +++ b/griptape/tasks/text_summary_task.py @@ -17,9 +17,7 @@ class TextSummaryTask(BaseTextInputTask): def summary_engine(self) -> Optional[BaseSummaryEngine]: if self._summary_engine is None: if self.structure is not None: - self._summary_engine = PromptSummaryEngine( - prompt_driver=self.structure.config.global_drivers.prompt_driver - ) + self._summary_engine = PromptSummaryEngine(prompt_driver=self.structure.config.prompt_driver) else: raise ValueError("Summary Engine is not set.") return self._summary_engine diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index 650d395c3b..1242bc59be 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -56,7 +56,7 @@ def image_generation_engine(self) -> VariationImageGenerationEngine: if self._image_generation_engine is None: if self.structure is not None: self._image_generation_engine = VariationImageGenerationEngine( - image_generation_driver=self.structure.config.global_drivers.image_generation_driver + image_generation_driver=self.structure.config.image_generation_driver ) else: raise ValueError("Image Generation Engine is not set.") diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 110a427da4..549d93e53f 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -21,7 +21,7 @@ class Chat: ) def default_output_fn(self, text: str) -> None: - if self.structure.config.global_drivers.prompt_driver.stream: + if self.structure.config.prompt_driver.stream: print(text, end="", flush=True) else: print(text) @@ -36,7 +36,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if self.structure.config.global_drivers.prompt_driver.stream: + if self.structure.config.prompt_driver.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/griptape/utils/prompt_stack.py b/griptape/utils/prompt_stack.py index f4dda14626..a5f3360309 100644 --- a/griptape/utils/prompt_stack.py +++ b/griptape/utils/prompt_stack.py @@ -66,7 +66,7 @@ def add_conversation_memory(self, memory: BaseConversationMemory, index: Optiona if memory.autoprune and hasattr(memory, "structure"): should_prune = True - prompt_driver = memory.structure.config.global_drivers.prompt_driver + prompt_driver = memory.structure.config.prompt_driver temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 80d3ea5a17..3ebd4225e3 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -33,7 +33,7 @@ class Stream: @structure.validator # pyright: ignore def validate_structure(self, _, structure: Structure): - if structure and not structure.config.global_drivers.prompt_driver.stream: + if structure and not structure.config.prompt_driver.stream: raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index 04fe4d7a13..d787897cfb 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -14,114 +14,43 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { - "global_drivers": { - "conversation_memory_driver": None, + "conversation_memory_driver": None, + "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, + "image_generation_driver": { + "image_generation_model_driver": { + "cfg_scale": 7, + "outpainting_mode": "PRECISE", + "quality": "standard", + "type": "BedrockTitanImageGenerationModelDriver", + }, + "image_height": 512, + "image_width": 512, + "model": "amazon.titan-image-generator-v1", + "seed": None, + "type": "AmazonBedrockImageGenerationDriver", + }, + "image_query_driver": { + "type": "AmazonBedrockImageQueryDriver", + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "max_tokens": 256, + "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, + }, + "prompt_driver": { + "max_tokens": None, + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, + "stream": False, + "temperature": 0.1, + "type": "AmazonBedrockPromptDriver", + }, + "vector_store_driver": { "embedding_driver": { "model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver", }, - "image_generation_driver": { - "image_generation_model_driver": { - "cfg_scale": 7, - "outpainting_mode": "PRECISE", - "quality": "standard", - "type": "BedrockTitanImageGenerationModelDriver", - }, - "image_height": 512, - "image_width": 512, - "model": "amazon.titan-image-generator-v1", - "seed": None, - "type": "AmazonBedrockImageGenerationDriver", - }, - "image_query_driver": { - "type": "AmazonBedrockImageQueryDriver", - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "max_tokens": 256, - "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, - }, - "prompt_driver": { - "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, - "stream": False, - "temperature": 0.1, - "type": "AmazonBedrockPromptDriver", - }, - "type": "StructureGlobalDriversConfig", - "vector_store_driver": { - "embedding_driver": { - "model": "amazon.titan-embed-text-v1", - "type": "AmazonBedrockTitanEmbeddingDriver", - }, - "type": "LocalVectorStoreDriver", - }, + "type": "LocalVectorStoreDriver", }, "type": "AmazonBedrockStructureConfig", - "task_memory": { - "type": "StructureTaskMemoryConfig", - "query_engine": { - "type": "StructureTaskMemoryQueryEngineConfig", - "prompt_driver": { - "type": "AmazonBedrockPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, - "stream": False, - }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "AmazonBedrockTitanEmbeddingDriver", - "model": "amazon.titan-embed-text-v1", - }, - }, - }, - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "AmazonBedrockPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": { - "type": "BedrockClaudePromptModelDriver", - "top_k": 250, - "top_p": 0.999, - }, - "stream": False, - }, - }, - "json": { - "type": "StructureTaskMemoryExtractionEngineJsonConfig", - "prompt_driver": { - "type": "AmazonBedrockPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": { - "type": "BedrockClaudePromptModelDriver", - "top_k": 250, - "top_p": 0.999, - }, - "stream": False, - }, - }, - }, - "summary_engine": { - "type": "StructureTaskMemorySummaryEngineConfig", - "prompt_driver": { - "type": "AmazonBedrockPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, - "stream": False, - }, - }, - }, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index 5653ddb46a..596365eb20 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -15,100 +15,35 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "AnthropicStructureConfig", - "global_drivers": { - "type": "StructureGlobalDriversConfig", - "prompt_driver": { - "type": "AnthropicPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "claude-3-opus-20240229", - "top_p": 0.999, - "top_k": 250, - }, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": { - "type": "AnthropicImageQueryDriver", - "model": "claude-3-opus-20240229", - "max_tokens": 256, - }, + "prompt_driver": { + "type": "AnthropicPromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "model": "claude-3-opus-20240229", + "top_p": 0.999, + "top_k": 250, + }, + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": { + "type": "AnthropicImageQueryDriver", + "model": "claude-3-opus-20240229", + "max_tokens": 256, + }, + "embedding_driver": { + "type": "VoyageAiEmbeddingDriver", + "model": "voyage-large-2", + "input_type": "document", + }, + "vector_store_driver": { + "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "VoyageAiEmbeddingDriver", "model": "voyage-large-2", "input_type": "document", }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "VoyageAiEmbeddingDriver", - "model": "voyage-large-2", - "input_type": "document", - }, - }, - "conversation_memory_driver": None, - }, - "task_memory": { - "type": "StructureTaskMemoryConfig", - "query_engine": { - "type": "StructureTaskMemoryQueryEngineConfig", - "prompt_driver": { - "type": "AnthropicPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "claude-3-opus-20240229", - "top_p": 0.999, - "top_k": 250, - }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "VoyageAiEmbeddingDriver", - "model": "voyage-large-2", - "input_type": "document", - }, - }, - }, - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "AnthropicPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "claude-3-opus-20240229", - "top_p": 0.999, - "top_k": 250, - }, - }, - "json": { - "type": "StructureTaskMemoryExtractionEngineJsonConfig", - "prompt_driver": { - "type": "AnthropicPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "claude-3-opus-20240229", - "top_p": 0.999, - "top_k": 250, - }, - }, - }, - "summary_engine": { - "type": "StructureTaskMemorySummaryEngineConfig", - "prompt_driver": { - "type": "AnthropicPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "claude-3-opus-20240229", - "top_p": 0.999, - "top_k": 250, - }, - }, }, + "conversation_memory_driver": None, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_structure_config.py index ad9283d027..dcca9e29dc 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_structure_config.py @@ -14,99 +14,33 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "GoogleStructureConfig", - "global_drivers": { - "type": "StructureGlobalDriversConfig", - "prompt_driver": { - "type": "GooglePromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "gemini-pro", - "top_p": None, - "top_k": None, - }, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, + "prompt_driver": { + "type": "GooglePromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "model": "gemini-pro", + "top_p": None, + "top_k": None, + }, + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": {"type": "DummyImageQueryDriver"}, + "embedding_driver": { + "type": "GoogleEmbeddingDriver", + "model": "models/embedding-001", + "task_type": "retrieval_document", + "title": None, + }, + "vector_store_driver": { + "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "GoogleEmbeddingDriver", "model": "models/embedding-001", "task_type": "retrieval_document", "title": None, }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "GoogleEmbeddingDriver", - "model": "models/embedding-001", - "task_type": "retrieval_document", - "title": None, - }, - }, - "conversation_memory_driver": None, - }, - "task_memory": { - "type": "StructureTaskMemoryConfig", - "query_engine": { - "type": "StructureTaskMemoryQueryEngineConfig", - "prompt_driver": { - "type": "GooglePromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "gemini-pro", - "top_p": None, - "top_k": None, - }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "GoogleEmbeddingDriver", - "model": "models/embedding-001", - "task_type": "retrieval_document", - "title": None, - }, - }, - }, - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "GooglePromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "gemini-pro", - "top_p": None, - "top_k": None, - }, - }, - "json": { - "type": "StructureTaskMemoryExtractionEngineJsonConfig", - "prompt_driver": { - "type": "GooglePromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "gemini-pro", - "top_p": None, - "top_k": None, - }, - }, - }, - "summary_engine": { - "type": "StructureTaskMemorySummaryEngineConfig", - "prompt_driver": { - "type": "GooglePromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "model": "gemini-pro", - "top_p": None, - "top_k": None, - }, - }, }, + "conversation_memory_driver": None, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index 60eeed0916..b06d6055fa 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -13,132 +13,54 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { - "type": "OpenAiStructureConfig", - "global_drivers": { - "type": "StructureGlobalDriversConfig", - "prompt_driver": { - "type": "OpenAiChatPromptDriver", - "base_url": None, - "model": "gpt-4", - "organization": None, - "response_format": None, - "seed": None, - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "user": "", - }, - "conversation_memory_driver": None, + "type": "StructureGlobalDriversConfig", + "prompt_driver": { + "type": "OpenAiChatPromptDriver", + "base_url": None, + "model": "gpt-4", + "organization": None, + "response_format": None, + "seed": None, + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "user": "", + }, + "conversation_memory_driver": None, + "embedding_driver": { + "base_url": None, + "model": "text-embedding-3-small", + "organization": None, + "type": "OpenAiEmbeddingDriver", + }, + "image_generation_driver": { + "api_version": None, + "base_url": None, + "image_size": "512x512", + "model": "dall-e-2", + "organization": None, + "quality": "standard", + "response_format": "b64_json", + "style": None, + "type": "OpenAiImageGenerationDriver", + }, + "image_query_driver": { + "api_version": None, + "base_url": None, + "image_quality": "auto", + "max_tokens": 256, + "model": "gpt-4-vision-preview", + "organization": None, + "type": "OpenAiVisionImageQueryDriver", + }, + "vector_store_driver": { "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", "organization": None, "type": "OpenAiEmbeddingDriver", }, - "image_generation_driver": { - "api_version": None, - "base_url": None, - "image_size": "512x512", - "model": "dall-e-2", - "organization": None, - "quality": "standard", - "response_format": "b64_json", - "style": None, - "type": "OpenAiImageGenerationDriver", - }, - "image_query_driver": { - "api_version": None, - "base_url": None, - "image_quality": "auto", - "max_tokens": 256, - "model": "gpt-4-vision-preview", - "organization": None, - "type": "OpenAiVisionImageQueryDriver", - }, - "vector_store_driver": { - "embedding_driver": { - "base_url": None, - "model": "text-embedding-3-small", - "organization": None, - "type": "OpenAiEmbeddingDriver", - }, - "type": "LocalVectorStoreDriver", - }, - }, - "task_memory": { - "type": "StructureTaskMemoryConfig", - "query_engine": { - "type": "StructureTaskMemoryQueryEngineConfig", - "prompt_driver": { - "base_url": None, - "type": "OpenAiChatPromptDriver", - "model": "gpt-4", - "organization": None, - "response_format": None, - "seed": None, - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "user": "", - }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": { - "type": "OpenAiEmbeddingDriver", - "base_url": None, - "organization": None, - "model": "text-embedding-3-small", - }, - }, - }, - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "OpenAiChatPromptDriver", - "base_url": None, - "model": "gpt-4", - "organization": None, - "response_format": None, - "seed": None, - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "user": "", - }, - }, - "json": { - "type": "StructureTaskMemoryExtractionEngineJsonConfig", - "prompt_driver": { - "type": "OpenAiChatPromptDriver", - "base_url": None, - "model": "gpt-4", - "organization": None, - "response_format": None, - "seed": None, - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "user": "", - }, - }, - }, - "summary_engine": { - "type": "StructureTaskMemorySummaryEngineConfig", - "prompt_driver": { - "type": "OpenAiChatPromptDriver", - "base_url": None, - "model": "gpt-4", - "organization": None, - "response_format": None, - "seed": None, - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "user": "", - }, - }, + "type": "LocalVectorStoreDriver", }, } diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index 13060cf92b..5cc3e2561d 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -10,63 +10,14 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "StructureConfig", - "global_drivers": { - "type": "StructureGlobalDriversConfig", - "prompt_driver": {"type": "DummyPromptDriver", "temperature": 0.1, "max_tokens": None, "stream": False}, - "conversation_memory_driver": None, + "prompt_driver": {"type": "DummyPromptDriver", "temperature": 0.1, "max_tokens": None, "stream": False}, + "conversation_memory_driver": None, + "embedding_driver": {"type": "DummyEmbeddingDriver"}, + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": {"type": "DummyImageQueryDriver"}, + "vector_store_driver": { "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "vector_store_driver": { - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "type": "DummyVectorStoreDriver", - }, - }, - "task_memory": { - "type": "StructureTaskMemoryConfig", - "query_engine": { - "type": "StructureTaskMemoryQueryEngineConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "stream": False, - "temperature": 0.1, - "max_tokens": None, - }, - "vector_store_driver": { - "type": "LocalVectorStoreDriver", - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - }, - }, - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - }, - "json": { - "type": "StructureTaskMemoryExtractionEngineJsonConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - }, - }, - "summary_engine": { - "type": "StructureTaskMemorySummaryEngineConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - }, + "type": "DummyVectorStoreDriver", }, } @@ -78,19 +29,11 @@ def test_unchanged_merge_config(self, config): config.merge_config( { "type": "StructureConfig", - "task_memory": { - "extraction_engine": { - "type": "StructureTaskMemoryExtractionEngineConfig", - "csv": { - "type": "StructureTaskMemoryExtractionEngineCsvConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - }, - } + "prompt_driver": { + "type": "DummyPromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, }, } ).to_dict() @@ -99,12 +42,12 @@ def test_unchanged_merge_config(self, config): def test_changed_merge_config(self, config): config = config.merge_config( - {"task_memory": {"extraction_engine": {"csv": {"prompt_driver": {"stream": True}}}}} + {"prompt_driver": {"type": "DummyPromptDriver", "temperature": 0.1, "max_tokens": None, "stream": False}} ) - assert config.task_memory.extraction_engine.csv.prompt_driver.stream is True + assert config.prompt_driver.temperature == 0.1 def test_dot_update(self, config): - config.task_memory.extraction_engine.csv.prompt_driver.stream = True + config.prompt_driver.max_tokens = 10 - assert config.task_memory.extraction_engine.csv.prompt_driver.stream is True + assert config.prompt_driver.max_tokens == 10 diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index 1ffa24ecdd..0366652b06 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -14,10 +14,8 @@ def task(self): def test_run(self, task): mock_config = MockStructureConfig() - assert isinstance(mock_config.global_drivers.prompt_driver, MockPromptDriver) - mock_config.global_drivers.prompt_driver.mock_output = ( - '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' - ) + assert isinstance(mock_config.prompt_driver, MockPromptDriver) + mock_config.prompt_driver.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' agent = Agent(config=mock_config) agent.add_task(task)