Skip to content

Commit

Permalink
Refactor how configs work
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Apr 30, 2024
1 parent a33324b commit e7d5fb8
Show file tree
Hide file tree
Showing 39 changed files with 278 additions and 923 deletions.
2 changes: 1 addition & 1 deletion docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ from griptape.tools import WebScraper


pipeline = Pipeline()
pipeline.config.global_drivers.prompt_driver.stream = True
pipeline.config.prompt_driver.stream = True
pipeline.add_tasks(ToolkitTask("Based on https://griptape.ai, tell me what griptape is.", tools=[WebScraper()]))

for artifact in Stream(pipeline).run():
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ from griptape.config import AmazonBedrockStructureConfig
from griptape.drivers import AmazonBedrockCohereEmbeddingDriver

custom_config = AmazonBedrockStructureConfig()
custom_config.global_drivers.embedding_driver = AmazonBedrockCohereEmbeddingDriver()
custom_config.embedding_driver = AmazonBedrockCohereEmbeddingDriver()
custom_config.merge_config(
{
"task_memory": {
Expand Down
14 changes: 0 additions & 14 deletions griptape/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
from .base_config import BaseConfig

from .structure_global_drivers_config import StructureGlobalDriversConfig
from .structure_task_memory_extraction_engine_csv_config import StructureTaskMemoryExtractionEngineCsvConfig
from .structure_task_memory_extraction_engine_json_config import StructureTaskMemoryExtractionEngineJsonConfig
from .structure_task_memory_extraction_engine_config import StructureTaskMemoryExtractionEngineConfig
from .structure_task_memory_query_engine_config import StructureTaskMemoryQueryEngineConfig
from .structure_task_memory_summary_engine_config import StructureTaskMemorySummaryEngineConfig
from .structure_task_memory_config import StructureTaskMemoryConfig
from .base_structure_config import BaseStructureConfig

from .structure_config import StructureConfig
Expand All @@ -19,13 +12,6 @@
__all__ = [
"BaseConfig",
"BaseStructureConfig",
"StructureTaskMemoryConfig",
"StructureGlobalDriversConfig",
"StructureTaskMemoryQueryEngineConfig",
"StructureTaskMemorySummaryEngineConfig",
"StructureTaskMemoryExtractionEngineConfig",
"StructureTaskMemoryExtractionEngineCsvConfig",
"StructureTaskMemoryExtractionEngineJsonConfig",
"StructureConfig",
"OpenAiStructureConfig",
"AmazonBedrockStructureConfig",
Expand Down
69 changes: 16 additions & 53 deletions griptape/config/amazon_bedrock_structure_config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
from attrs import Factory, define, field
from attrs import define

from griptape.config import (
BaseStructureConfig,
StructureGlobalDriversConfig,
StructureTaskMemoryConfig,
StructureTaskMemoryExtractionEngineConfig,
StructureTaskMemoryExtractionEngineCsvConfig,
StructureTaskMemoryExtractionEngineJsonConfig,
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.config import StructureConfig
from griptape.drivers import (
AmazonBedrockImageGenerationDriver,
AmazonBedrockImageQueryDriver,
Expand All @@ -23,47 +14,19 @@


@define()
class AmazonBedrockStructureConfig(BaseStructureConfig):
global_drivers: StructureGlobalDriversConfig = field(
default=Factory(
lambda: StructureGlobalDriversConfig(
prompt_driver=AmazonBedrockPromptDriver(
model="anthropic.claude-3-sonnet-20240229-v1:0",
stream=False,
prompt_model_driver=BedrockClaudePromptModelDriver(),
),
image_generation_driver=AmazonBedrockImageGenerationDriver(
model="amazon.titan-image-generator-v1",
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(),
),
image_query_driver=AmazonBedrockImageQueryDriver(
model="anthropic.claude-3-sonnet-20240229-v1:0",
image_query_model_driver=BedrockClaudeImageQueryModelDriver(),
),
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1")
),
)
),
kw_only=True,
metadata={"serializable": True},
class AmazonBedrockStructureConfig(StructureConfig):
prompt_driver = AmazonBedrockPromptDriver(
model="anthropic.claude-3-sonnet-20240229-v1:0",
stream=False,
prompt_model_driver=BedrockClaudePromptModelDriver(),
)
image_generation_driver = AmazonBedrockImageGenerationDriver(
model="amazon.titan-image-generator-v1", image_generation_model_driver=BedrockTitanImageGenerationModelDriver()
)
image_query_driver = AmazonBedrockImageQueryDriver(
model="anthropic.claude-3-sonnet-20240229-v1:0", image_query_model_driver=BedrockClaudeImageQueryModelDriver()
)
task_memory: StructureTaskMemoryConfig = field(
default=Factory(
lambda self: StructureTaskMemoryConfig(
query_engine=StructureTaskMemoryQueryEngineConfig(
prompt_driver=self.global_drivers.prompt_driver,
vector_store_driver=self.global_drivers.vector_store_driver,
),
extraction_engine=StructureTaskMemoryExtractionEngineConfig(
csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver),
json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver),
),
summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver),
),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
embedding_driver = AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1")
vector_store_driver = LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1")
)
55 changes: 9 additions & 46 deletions griptape/config/anthropic_structure_config.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,17 @@
from attrs import Factory, define, field
from attrs import define

from griptape.config import (
BaseStructureConfig,
StructureGlobalDriversConfig,
StructureTaskMemoryConfig,
StructureTaskMemoryExtractionEngineConfig,
StructureTaskMemoryExtractionEngineCsvConfig,
StructureTaskMemoryExtractionEngineJsonConfig,
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.config import StructureConfig
from griptape.drivers import (
LocalVectorStoreDriver,
AnthropicPromptDriver,
AnthropicImageQueryDriver,
AnthropicPromptDriver,
LocalVectorStoreDriver,
VoyageAiEmbeddingDriver,
)


@define
class AnthropicStructureConfig(BaseStructureConfig):
global_drivers: StructureGlobalDriversConfig = field(
default=Factory(
lambda: StructureGlobalDriversConfig(
prompt_driver=AnthropicPromptDriver(model="claude-3-opus-20240229"),
embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")
),
image_query_driver=AnthropicImageQueryDriver(model="claude-3-opus-20240229"),
)
),
kw_only=True,
metadata={"serializable": True},
)
task_memory: StructureTaskMemoryConfig = field(
default=Factory(
lambda self: StructureTaskMemoryConfig(
query_engine=StructureTaskMemoryQueryEngineConfig(
prompt_driver=self.global_drivers.prompt_driver,
vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver),
),
extraction_engine=StructureTaskMemoryExtractionEngineConfig(
csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver),
json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver),
),
summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver),
),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)
class AnthropicStructureConfig(StructureConfig):
prompt_driver = AnthropicPromptDriver(model="claude-3-opus-20240229")
embedding_driver = VoyageAiEmbeddingDriver(model="voyage-large-2")
vector_store_driver = LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2"))
image_query_driver = AnthropicImageQueryDriver(model="claude-3-opus-20240229")
21 changes: 18 additions & 3 deletions griptape/config/base_structure_config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
from __future__ import annotations

from abc import ABC
from typing import Optional

from attr import define, field

from griptape.config import BaseConfig, StructureGlobalDriversConfig, StructureTaskMemoryConfig
from griptape.config import BaseConfig
from griptape.drivers import (
BaseConversationMemoryDriver,
BaseEmbeddingDriver,
BaseImageGenerationDriver,
BaseImageQueryDriver,
BasePromptDriver,
BaseVectorStoreDriver,
)
from griptape.utils import dict_merge


@define
class BaseStructureConfig(BaseConfig, ABC):
global_drivers: StructureGlobalDriversConfig = field(kw_only=True, metadata={"serializable": True})
task_memory: StructureTaskMemoryConfig = field(kw_only=True, metadata={"serializable": True})
prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True})
image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True})
image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True})
embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True})
vector_store_driver: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True})
conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field(
default=None, kw_only=True, metadata={"serializable": True}
)

def merge_config(self, config: dict) -> BaseStructureConfig:
base_config = self.to_dict()
Expand Down
51 changes: 7 additions & 44 deletions griptape/config/google_structure_config.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,11 @@
from attrs import Factory, define, field
from attrs import define

from griptape.config import (
BaseStructureConfig,
StructureGlobalDriversConfig,
StructureTaskMemoryConfig,
StructureTaskMemoryExtractionEngineConfig,
StructureTaskMemoryExtractionEngineCsvConfig,
StructureTaskMemoryExtractionEngineJsonConfig,
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.drivers import LocalVectorStoreDriver, GooglePromptDriver, GoogleEmbeddingDriver
from griptape.config import StructureConfig
from griptape.drivers import GoogleEmbeddingDriver, GooglePromptDriver, LocalVectorStoreDriver


@define
class GoogleStructureConfig(BaseStructureConfig):
global_drivers: StructureGlobalDriversConfig = field(
default=Factory(
lambda: StructureGlobalDriversConfig(
prompt_driver=GooglePromptDriver(model="gemini-pro"),
embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001")
),
)
),
kw_only=True,
metadata={"serializable": True},
)
task_memory: StructureTaskMemoryConfig = field(
default=Factory(
lambda self: StructureTaskMemoryConfig(
query_engine=StructureTaskMemoryQueryEngineConfig(
prompt_driver=self.global_drivers.prompt_driver,
vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver),
),
extraction_engine=StructureTaskMemoryExtractionEngineConfig(
csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver),
json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver),
),
summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver),
),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)
class GoogleStructureConfig(StructureConfig):
prompt_driver = GooglePromptDriver(model="gemini-pro")
embedding_driver = GoogleEmbeddingDriver(model="models/embedding-001")
vector_store_driver = LocalVectorStoreDriver(embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001"))
53 changes: 8 additions & 45 deletions griptape/config/openai_structure_config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
from attrs import Factory, define, field
from attrs import define

from griptape.config import (
BaseStructureConfig,
StructureGlobalDriversConfig,
StructureTaskMemoryConfig,
StructureTaskMemoryExtractionEngineConfig,
StructureTaskMemoryExtractionEngineCsvConfig,
StructureTaskMemoryExtractionEngineJsonConfig,
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.config import StructureConfig
from griptape.drivers import (
LocalVectorStoreDriver,
OpenAiChatPromptDriver,
Expand All @@ -20,37 +11,9 @@


@define
class OpenAiStructureConfig(BaseStructureConfig):
global_drivers: StructureGlobalDriversConfig = field(
default=Factory(
lambda: StructureGlobalDriversConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"),
image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview"),
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")
),
)
),
kw_only=True,
metadata={"serializable": True},
)
task_memory: StructureTaskMemoryConfig = field(
default=Factory(
lambda self: StructureTaskMemoryConfig(
query_engine=StructureTaskMemoryQueryEngineConfig(
prompt_driver=self.global_drivers.prompt_driver,
vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver),
),
extraction_engine=StructureTaskMemoryExtractionEngineConfig(
csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver),
json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver),
),
summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver),
),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)
class OpenAiStructureConfig(StructureConfig):
prompt_driver = OpenAiChatPromptDriver(model="gpt-4")
image_generation_driver = OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512")
image_query_driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview")
embedding_driver = OpenAiEmbeddingDriver(model="text-embedding-3-small")
vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"))
Loading

0 comments on commit e7d5fb8

Please sign in to comment.