ResponseRagStage and PromptResponseRagModule updates (#1056)
Co-authored-by: Collin Dutter <[email protected]>
vasinov and collindutter authored Aug 12, 2024
1 parent e42cb91 · commit f77d8e8
Showing 35 changed files with 113 additions and 183 deletions.
CHANGELOG.md (4 changes: 4 additions & 0 deletions)
@@ -13,10 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Method `try_find_task` to `Structure`.
 - `TranslateQueryRagModule` `RagEngine` module for translating input queries.
 - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events.
+- Unique name generation for all `RagEngine` modules.
 
 ### Changed
 - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`.
 - **BREAKING**: Removed `EventPublisherMixin`.
+- **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly.
+- **BREAKING**: Removed before and after response modules from `ResponseRagStage`.
+- **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`.
 - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible.
 
 ## [0.29.0] - 2024-07-30
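Taken together, the breaking changes above are a small, mechanical migration. A minimal before/after sketch (the OpenAI driver and the bare `ResponseRagStage` are illustrative, mirroring the docs examples updated below):

from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines.rag import RagEngine
from griptape.engines.rag.modules import PromptResponseRagModule
from griptape.engines.rag.stages import ResponseRagStage

prompt_driver = OpenAiChatPromptDriver(model="gpt-4o")

# Before this commit: one response module, one artifact in RagContext.output.
# engine = RagEngine(
#     response_stage=ResponseRagStage(
#         response_module=PromptResponseRagModule(prompt_driver=prompt_driver),
#     ),
# )
# answer = engine.process(rag_context).output

# After this commit: a list of response modules, one artifact per module
# in RagContext.outputs.
engine = RagEngine(
    response_stage=ResponseRagStage(
        response_modules=[PromptResponseRagModule(prompt_driver=prompt_driver)],
    ),
)
# answer = engine.process(rag_context).outputs[0]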
docs/examples/src/query_webpage_astra_db_1.py (2 changes: 1 addition & 1 deletion)
@@ -40,7 +40,7 @@
         ]
     ),
     response_stage=ResponseRagStage(
-        response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
+        response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
     ),
 )

docs/examples/src/talk_to_a_pdf_1.py (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@
         ]
     ),
     response_stage=ResponseRagStage(
-        response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
+        response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
     ),
 )
 vector_store_tool = RagClient(
docs/examples/src/talk_to_a_webpage_1.py (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@
         ]
     ),
     response_stage=ResponseRagStage(
-        response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
+        response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
     ),
 )

docs/griptape-framework/engines/src/rag_engines_1.py (23 changes: 17 additions & 6 deletions)
@@ -4,19 +4,24 @@
 from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule
 from griptape.engines.rag.stages import QueryRagStage, ResponseRagStage, RetrievalRagStage
 from griptape.loaders import WebLoader
+from griptape.rules import Rule, Ruleset
 
 prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0)
 
 vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())
 
 artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai")
 
 if isinstance(artifacts, ErrorArtifact):
-    raise ValueError(artifacts.value)
+    raise Exception(artifacts.value)
 
-vector_store.upsert_text_artifacts({"griptape": artifacts})
+vector_store.upsert_text_artifacts(
+    {
+        "griptape": artifacts,
+    }
+)
 
 rag_engine = RagEngine(
-    query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="English")]),
+    query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="english")]),
     retrieval_stage=RetrievalRagStage(
         max_chunks=5,
         retrieval_modules=[
@@ -25,12 +30,18 @@
             )
         ],
     ),
-    response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=prompt_driver)),
+    response_stage=ResponseRagStage(
+        response_modules=[
+            PromptResponseRagModule(
+                prompt_driver=prompt_driver, rulesets=[Ruleset(name="persona", rules=[Rule("Talk like a pirate")])]
+            )
+        ]
+    ),
 )
 
 rag_context = RagContext(
     query="¿Qué ofrecen los servicios en la nube de Griptape?",
     module_configs={"MyAwesomeRetriever": {"query_params": {"namespace": "griptape"}}},
 )
 
-print(rag_engine.process(rag_context).output.to_text())
+print(rag_engine.process(rag_context).outputs[0].to_text())
docs/griptape-framework/structures/src/task_memory_6.py (2 changes: 1 addition & 1 deletion)
@@ -34,7 +34,7 @@
             ]
         ),
         response_stage=ResponseRagStage(
-            response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
+            response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
         ),
     ),
     retrieval_rag_module_name="VectorStoreRetrievalRagModule",
docs/griptape-framework/structures/src/tasks_9.py (2 changes: 1 addition & 1 deletion)
@@ -29,7 +29,7 @@
             ]
         ),
         response_stage=ResponseRagStage(
-            response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
+            response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
         ),
     ),
 )
docs/griptape-tools/official-tools/src/rag_client_1.py (2 changes: 1 addition & 1 deletion)
@@ -27,7 +27,7 @@
             ]
         ),
         response_stage=ResponseRagStage(
-            response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
+            response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
         ),
     ),
 )
griptape/engines/rag/modules/__init__.py (4 changes: 0 additions & 4 deletions)
@@ -10,8 +10,6 @@
 from .response.base_after_response_rag_module import BaseAfterResponseRagModule
 from .response.base_response_rag_module import BaseResponseRagModule
 from .response.prompt_response_rag_module import PromptResponseRagModule
-from .response.rulesets_before_response_rag_module import RulesetsBeforeResponseRagModule
-from .response.metadata_before_response_rag_module import MetadataBeforeResponseRagModule
 from .response.text_chunks_response_rag_module import TextChunksResponseRagModule
 from .response.footnote_prompt_response_rag_module import FootnotePromptResponseRagModule

@@ -28,8 +26,6 @@
"BaseAfterResponseRagModule",
"BaseResponseRagModule",
"PromptResponseRagModule",
"RulesetsBeforeResponseRagModule",
"MetadataBeforeResponseRagModule",
"TextChunksResponseRagModule",
"FootnotePromptResponseRagModule",
]
griptape/engines/rag/modules/base_rag_module.py (5 changes: 4 additions & 1 deletion)
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import uuid
 from abc import ABC
 from concurrent import futures
 from typing import TYPE_CHECKING, Any, Callable, Optional
@@ -14,7 +15,9 @@

 @define(kw_only=True)
 class BaseRagModule(ABC):
-    name: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True)
+    name: str = field(
+        default=Factory(lambda self: f"{self.__class__.__name__}-{uuid.uuid4().hex}", takes_self=True), kw_only=True
+    )
     futures_executor_fn: Callable[[], futures.Executor] = field(
         default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
     )
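Module names now default to "ClassName-<uuid hex>", so two modules of the same class no longer collide within a stage; the trade-off is that auto-generated names are not stable across runs. A short sketch (the explicit "MyAnswerModule" name is a hypothetical choice, following the `module_configs` pattern in the rag_engines_1.py example above):

from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines.rag.modules import PromptResponseRagModule

prompt_driver = OpenAiChatPromptDriver(model="gpt-4o")

# Two instances of the same class now get distinct default names.
a = PromptResponseRagModule(prompt_driver=prompt_driver)
b = PromptResponseRagModule(prompt_driver=prompt_driver)
assert a.name != b.name

# A module addressed through RagContext.module_configs should pin its name
# explicitly so the config key stays stable across runs.
pinned = PromptResponseRagModule(name="MyAnswerModule", prompt_driver=prompt_driver)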
griptape/engines/rag/modules/response/base_response_rag_module.py
@@ -2,11 +2,12 @@

 from attrs import define
 
+from griptape.artifacts import BaseArtifact
 from griptape.engines.rag import RagContext
 from griptape.engines.rag.modules import BaseRagModule
 
 
 @define(kw_only=True)
 class BaseResponseRagModule(BaseRagModule, ABC):
     @abstractmethod
-    def run(self, context: RagContext) -> RagContext: ...
+    def run(self, context: RagContext) -> BaseArtifact: ...

This file was deleted: griptape/engines/rag/modules/response/metadata_before_response_rag_module.py

griptape/engines/rag/modules/response/prompt_response_rag_module.py (29 changes: 17 additions & 12 deletions)
@@ -1,27 +1,30 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from attrs import Factory, define, field
 
 from griptape.artifacts.text_artifact import TextArtifact
 from griptape.engines.rag.modules import BaseResponseRagModule
+from griptape.mixins import RuleMixin
 from griptape.utils import J2
 
 if TYPE_CHECKING:
+    from griptape.artifacts import BaseArtifact
     from griptape.drivers import BasePromptDriver
     from griptape.engines.rag import RagContext
 
 
 @define(kw_only=True)
-class PromptResponseRagModule(BaseResponseRagModule):
-    answer_token_offset: int = field(default=400)
+class PromptResponseRagModule(BaseResponseRagModule, RuleMixin):
     prompt_driver: BasePromptDriver = field()
+    answer_token_offset: int = field(default=400)
+    metadata: Optional[str] = field(default=None)
     generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field(
         default=Factory(lambda self: self.default_system_template_generator, takes_self=True),
     )
 
-    def run(self, context: RagContext) -> RagContext:
+    def run(self, context: RagContext) -> BaseArtifact:
         query = context.query
         tokenizer = self.prompt_driver.tokenizer
         included_chunks = []
@@ -45,15 +48,17 @@ def run(self, context: RagContext) -> RagContext:
         output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact()
 
         if isinstance(output, TextArtifact):
-            context.output = output
+            return output
         else:
             raise ValueError("Prompt driver did not return a TextArtifact")
 
-        return context
-
     def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
-        return J2("engines/rag/modules/response/prompt/system.j2").render(
-            text_chunks=[c.to_text() for c in artifacts],
-            before_system_prompt="\n\n".join(context.before_query),
-            after_system_prompt="\n\n".join(context.after_query),
-        )
+        params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]}
+
+        if len(self.all_rulesets) > 0:
+            params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets)
+
+        if self.metadata is not None:
+            params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata)
+
+        return J2("engines/rag/modules/response/prompt/system.j2").render(**params)

This file was deleted: griptape/engines/rag/modules/response/rulesets_before_response_rag_module.py
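With the before/after response modules deleted, rulesets and metadata attach directly to `PromptResponseRagModule`. A minimal sketch of the new surface (the metadata string is illustrative):

from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines.rag.modules import PromptResponseRagModule
from griptape.rules import Rule, Ruleset

module = PromptResponseRagModule(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
    # Previously supplied via RulesetsBeforeResponseRagModule:
    rulesets=[Ruleset(name="persona", rules=[Rule("Talk like a pirate")])],
    # Previously supplied via MetadataBeforeResponseRagModule:
    metadata="Source: https://www.griptape.ai",
)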

griptape/engines/rag/modules/response/text_chunks_response_rag_module.py
@@ -1,13 +1,11 @@
 from attrs import define
 
-from griptape.artifacts import ListArtifact
+from griptape.artifacts import BaseArtifact, ListArtifact
 from griptape.engines.rag import RagContext
 from griptape.engines.rag.modules import BaseResponseRagModule
 
 
 @define(kw_only=True)
 class TextChunksResponseRagModule(BaseResponseRagModule):
-    def run(self, context: RagContext) -> RagContext:
-        context.output = ListArtifact(context.text_chunks)
-
-        return context
+    def run(self, context: RagContext) -> BaseArtifact:
+        return ListArtifact(context.text_chunks)
griptape/engines/rag/rag_context.py (6 changes: 3 additions & 3 deletions)
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 from attrs import define, field
 

@@ -22,15 +22,15 @@ class RagContext(SerializableMixin):
         before_query: An optional list of strings to add before the query in response modules.
         after_query: An optional list of strings to add after the query in response modules.
         text_chunks: A list of text chunks to pass around from the retrieval stage to the response stage.
-        output: Final output from the response stage.
+        outputs: List of outputs from the response stage.
     """
 
     query: str = field(metadata={"serializable": True})
     module_configs: dict[str, dict] = field(factory=dict, metadata={"serializable": True})
     before_query: list[str] = field(factory=list, metadata={"serializable": True})
     after_query: list[str] = field(factory=list, metadata={"serializable": True})
     text_chunks: list[TextArtifact] = field(factory=list, metadata={"serializable": True})
-    output: Optional[BaseArtifact] = field(default=None, metadata={"serializable": True})
+    outputs: list[BaseArtifact] = field(factory=list, metadata={"serializable": True})
 
     def get_references(self) -> list[Reference]:
         return utils.references_from_artifacts(self.text_chunks)
griptape/engines/rag/stages/query_rag_stage.py (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@ def modules(self) -> Sequence[BaseRagModule]:
         return self.query_modules
 
     def run(self, context: RagContext) -> RagContext:
-        logging.info("QueryStage: running %s query generation modules sequentially", len(self.query_modules))
+        logging.info("QueryRagStage: running %s query generation modules sequentially", len(self.query_modules))
 
         [qm.run(context) for qm in self.query_modules]

griptape/engines/rag/stages/response_rag_stage.py (28 changes: 7 additions & 21 deletions)
@@ -5,49 +5,35 @@

 from attrs import define, field
 
+from griptape import utils
 from griptape.engines.rag.stages import BaseRagStage
 
 if TYPE_CHECKING:
     from griptape.engines.rag import RagContext
     from griptape.engines.rag.modules import (
-        BaseAfterResponseRagModule,
-        BaseBeforeResponseRagModule,
         BaseRagModule,
         BaseResponseRagModule,
     )
 
 
 @define(kw_only=True)
 class ResponseRagStage(BaseRagStage):
-    before_response_modules: list[BaseBeforeResponseRagModule] = field(factory=list)
-    response_module: BaseResponseRagModule = field()
-    after_response_modules: list[BaseAfterResponseRagModule] = field(factory=list)
+    response_modules: list[BaseResponseRagModule] = field()
 
     @property
     def modules(self) -> list[BaseRagModule]:
         ms = []
 
-        ms.extend(self.before_response_modules)
-        ms.extend(self.after_response_modules)
-
-        if self.response_module is not None:
-            ms.append(self.response_module)
+        ms.extend(self.response_modules)
 
         return ms
 
     def run(self, context: RagContext) -> RagContext:
-        logging.info("GenerationStage: running %s before modules sequentially", len(self.before_response_modules))
-
-        for generator in self.before_response_modules:
-            context = generator.run(context)
-
-        logging.info("GenerationStage: running generation module")
-
-        context = self.response_module.run(context)
+        logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules))
 
-        logging.info("GenerationStage: running %s after modules sequentially", len(self.after_response_modules))
+        with self.futures_executor_fn() as executor:
+            results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.response_modules])
 
-        for generator in self.after_response_modules:
-            context = generator.run(context)
+        context.outputs = results
 
         return context
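The stage now fans out over all response modules and collects one artifact per module into `RagContext.outputs`. A trimmed sketch of the new behavior (query and retrieval stages omitted; assumes an OpenAI key is configured so the prompt module can actually answer):

from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines.rag import RagContext, RagEngine
from griptape.engines.rag.modules import PromptResponseRagModule, TextChunksResponseRagModule
from griptape.engines.rag.stages import ResponseRagStage

engine = RagEngine(
    response_stage=ResponseRagStage(
        response_modules=[
            PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")),
            TextChunksResponseRagModule(),  # returns the retrieved chunks as a ListArtifact
        ]
    ),
)

context = engine.process(RagContext(query="What does Griptape offer?"))
print(len(context.outputs))  # one artifact per response module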