ResponseRagStage and PromptResponseRagModule updates #1056

Merged: 15 commits, Aug 12, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -13,10 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Method `try_find_task` to `Structure`.
- `TranslateQueryRagModule` `RagEngine` module for translating input queries.
- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events.
- Unique name generation for all `RagEngine` modules.

### Changed
- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`.
- **BREAKING**: Removed `EventPublisherMixin`.
- **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly.
**BREAKING**: Removed before and after response modules from `ResponseRagStage`.
**BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`.
Review comment (Member):
Missing hyphens on the two new **BREAKING** entries.

- `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible.

## [0.29.0] - 2024-07-30
2 changes: 1 addition & 1 deletion docs/examples/src/query_webpage_astra_db_1.py
@@ -40,7 +40,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
)

2 changes: 1 addition & 1 deletion docs/examples/src/talk_to_a_pdf_1.py
@@ -22,7 +22,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
)
vector_store_tool = RagClient(
2 changes: 1 addition & 1 deletion docs/examples/src/talk_to_a_webpage_1.py
@@ -22,7 +22,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
)

24 changes: 15 additions & 9 deletions docs/griptape-framework/engines/src/rag_engines_1.py
@@ -1,22 +1,22 @@
from griptape.artifacts import ErrorArtifact
from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver
from griptape.engines.rag import RagContext, RagEngine
from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule
from griptape.engines.rag.stages import QueryRagStage, ResponseRagStage, RetrievalRagStage
from griptape.loaders import WebLoader
from griptape.rules import Rule, Ruleset

prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0)

vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())

artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai")
if isinstance(artifacts, ErrorArtifact):
raise ValueError(artifacts.value)

vector_store.upsert_text_artifacts({"griptape": artifacts})
vector_store.upsert_text_artifacts(
{
"griptape": WebLoader(max_tokens=500).load("https://www.griptape.ai"),
}
)

Review comment (Member, @collindutter, Aug 9, 2024):
Does this pass type checks?

Edit: it probably will not: #1057

rag_engine = RagEngine(
query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="English")]),
query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="english")]),
retrieval_stage=RetrievalRagStage(
max_chunks=5,
retrieval_modules=[
@@ -25,12 +25,18 @@
)
],
),
response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=prompt_driver)),
response_stage=ResponseRagStage(
response_modules=[
PromptResponseRagModule(
prompt_driver=prompt_driver, rulesets=[Ruleset(name="persona", rules=[Rule("Talk like a pirate")])]
)
]
),
)

rag_context = RagContext(
query="¿Qué ofrecen los servicios en la nube de Griptape?",
module_configs={"MyAwesomeRetriever": {"query_params": {"namespace": "griptape"}}},
)

print(rag_engine.process(rag_context).output.to_text())
print(rag_engine.process(rag_context).outputs[0].to_text())
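
For callers migrating from `output` to `outputs`, here is a minimal sketch of what the change means downstream. `FakeArtifact` and `FakeRagContext` are hypothetical stand-ins written for illustration, not griptape classes; only the `outputs` list and `to_text()` come from the diff above.

```python
from dataclasses import dataclass, field


@dataclass
class FakeArtifact:
    """Hypothetical stand-in for a griptape artifact; only to_text() is modeled."""

    value: str

    def to_text(self) -> str:
        return self.value


@dataclass
class FakeRagContext:
    """Hypothetical stand-in for RagContext, reduced to the new `outputs` field."""

    query: str
    outputs: list = field(default_factory=list)


def first_output_text(context: FakeRagContext) -> str:
    # `outputs` defaults to an empty list, so guard before indexing.
    if not context.outputs:
        raise ValueError("response stage produced no outputs")
    return context.outputs[0].to_text()


context = FakeRagContext(query="What does Griptape offer?")
context.outputs = [FakeArtifact("answer from module A"), FakeArtifact("answer from module B")]

print(first_output_text(context))
print([o.to_text() for o in context.outputs])
```

Unlike the old `Optional` field, an empty result is now an empty list rather than `None`, so the guard checks for emptiness before indexing.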
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/task_memory_6.py
@@ -34,7 +34,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
),
retrieval_rag_module_name="VectorStoreRetrievalRagModule",
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/tasks_9.py
@@ -29,7 +29,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
),
)
2 changes: 1 addition & 1 deletion docs/griptape-tools/official-tools/src/rag_client_1.py
@@ -27,7 +27,7 @@
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))
response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))]
),
),
)
4 changes: 0 additions & 4 deletions griptape/engines/rag/modules/__init__.py
@@ -10,8 +10,6 @@
from .response.base_after_response_rag_module import BaseAfterResponseRagModule
from .response.base_response_rag_module import BaseResponseRagModule
from .response.prompt_response_rag_module import PromptResponseRagModule
from .response.rulesets_before_response_rag_module import RulesetsBeforeResponseRagModule
from .response.metadata_before_response_rag_module import MetadataBeforeResponseRagModule
from .response.text_chunks_response_rag_module import TextChunksResponseRagModule
from .response.footnote_prompt_response_rag_module import FootnotePromptResponseRagModule

@@ -28,8 +26,6 @@
"BaseAfterResponseRagModule",
"BaseResponseRagModule",
"PromptResponseRagModule",
"RulesetsBeforeResponseRagModule",
"MetadataBeforeResponseRagModule",
"TextChunksResponseRagModule",
"FootnotePromptResponseRagModule",
]
5 changes: 4 additions & 1 deletion griptape/engines/rag/modules/base_rag_module.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import uuid
from abc import ABC
from concurrent import futures
from typing import TYPE_CHECKING, Any, Callable, Optional
@@ -14,7 +15,9 @@

@define(kw_only=True)
class BaseRagModule(ABC):
name: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True)
name: str = field(
default=Factory(lambda self: f"{self.__class__.__name__}-{uuid.uuid4().hex}", takes_self=True), kw_only=True
)
Review comment on lines +18 to +20 (Member):
What does name uniqueness get us?

Reply (Member, Author):
It makes it much easier to add multiple modules of the same type without having to explicitly define names. Maybe that's something we should add to tools as well?
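
To illustrate the point in the reply, here is a minimal sketch of the new default-name behavior. It uses `dataclasses` rather than `attrs`, and `SketchModule` and `SketchRetriever` are hypothetical stand-ins, not griptape classes; only the `"<ClassName>-<uuid hex>"` format is taken from the diff.

```python
import uuid
from dataclasses import dataclass


@dataclass
class SketchModule:
    """Hypothetical stand-in for BaseRagModule (dataclasses instead of attrs)."""

    name: str = ""

    def __post_init__(self) -> None:
        # Mirror the new default: "<ClassName>-<32 hex chars>" when no name is given.
        if not self.name:
            self.name = f"{self.__class__.__name__}-{uuid.uuid4().hex}"


class SketchRetriever(SketchModule):
    pass


a = SketchRetriever()
b = SketchRetriever()
named = SketchRetriever(name="MyAwesomeRetriever")

print(a.name != b.name)  # two unnamed modules of the same type no longer collide
print(named.name)        # an explicit name is kept as given
```

Because the generated suffix differs on every instantiation, code that addresses a module by name (for example the `module_configs` key in `rag_engines_1.py`) would presumably still need an explicit `name`.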

futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
)
Original file line number Diff line number Diff line change
@@ -2,11 +2,12 @@

from attrs import define

from griptape.artifacts import BaseArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import BaseRagModule


@define(kw_only=True)
class BaseResponseRagModule(BaseRagModule, ABC):
@abstractmethod
def run(self, context: RagContext) -> RagContext: ...
def run(self, context: RagContext) -> BaseArtifact: ...

This file was deleted.

28 changes: 17 additions & 11 deletions griptape/engines/rag/modules/response/prompt_response_rag_module.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional

from attrs import Factory, define, field

@@ -9,19 +9,23 @@
from griptape.utils import J2

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
from griptape.drivers import BasePromptDriver
from griptape.engines.rag import RagContext
from griptape.rules import Ruleset


@define(kw_only=True)
class PromptResponseRagModule(BaseResponseRagModule):
answer_token_offset: int = field(default=400)
prompt_driver: BasePromptDriver = field()
answer_token_offset: int = field(default=400)
rulesets: list[Ruleset] = field(factory=list)
Review comment (Member):
This should use `RuleMixin` instead of defining `rulesets` itself.

metadata: Optional[str] = field(default=None)
generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field(
default=Factory(lambda self: self.default_system_template_generator, takes_self=True),
)

def run(self, context: RagContext) -> RagContext:
def run(self, context: RagContext) -> BaseArtifact:
query = context.query
tokenizer = self.prompt_driver.tokenizer
included_chunks = []
@@ -45,15 +49,17 @@ def run(self, context: RagContext) -> RagContext:
output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact()

if isinstance(output, TextArtifact):
context.output = output
return output
else:
raise ValueError("Prompt driver did not return a TextArtifact")

return context

def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
return J2("engines/rag/modules/response/prompt/system.j2").render(
text_chunks=[c.to_text() for c in artifacts],
before_system_prompt="\n\n".join(context.before_query),
after_system_prompt="\n\n".join(context.after_query),
)
params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]}

if len(self.rulesets) > 0:
params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets)
Review comment (Member):
After implementing `RuleMixin`, replace `self.rulesets` with `self.all_rulesets`.


if self.metadata is not None:
params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata)

return J2("engines/rag/modules/response/prompt/system.j2").render(**params)

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from attrs import define

from griptape.artifacts import ListArtifact
from griptape.artifacts import BaseArtifact, ListArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import BaseResponseRagModule


@define(kw_only=True)
class TextChunksResponseRagModule(BaseResponseRagModule):
def run(self, context: RagContext) -> RagContext:
context.output = ListArtifact(context.text_chunks)

return context
def run(self, context: RagContext) -> BaseArtifact:
return ListArtifact(context.text_chunks)
6 changes: 3 additions & 3 deletions griptape/engines/rag/rag_context.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from attrs import define, field

@@ -22,15 +22,15 @@ class RagContext(SerializableMixin):
before_query: An optional list of strings to add before the query in response modules.
after_query: An optional list of strings to add after the query in response modules.
text_chunks: A list of text chunks to pass around from the retrieval stage to the response stage.
output: Final output from the response stage.
outputs: List of outputs from the response stage.
"""

query: str = field(metadata={"serializable": True})
module_configs: dict[str, dict] = field(factory=dict, metadata={"serializable": True})
before_query: list[str] = field(factory=list, metadata={"serializable": True})
after_query: list[str] = field(factory=list, metadata={"serializable": True})
text_chunks: list[TextArtifact] = field(factory=list, metadata={"serializable": True})
output: Optional[BaseArtifact] = field(default=None, metadata={"serializable": True})
outputs: list[BaseArtifact] = field(factory=list, metadata={"serializable": True})

def get_references(self) -> list[Reference]:
return utils.references_from_artifacts(self.text_chunks)
2 changes: 1 addition & 1 deletion griptape/engines/rag/stages/query_rag_stage.py
@@ -23,7 +23,7 @@
return self.query_modules

def run(self, context: RagContext) -> RagContext:
logging.info("QueryStage: running %s query generation modules sequentially", len(self.query_modules))
logging.info("QueryRagStage: running %s query generation modules sequentially", len(self.query_modules))

Codecov warning (codecov/patch), griptape/engines/rag/stages/query_rag_stage.py#L26: added line #L26 was not covered by tests.

[qm.run(context) for qm in self.query_modules]

28 changes: 7 additions & 21 deletions griptape/engines/rag/stages/response_rag_stage.py
@@ -5,49 +5,35 @@

from attrs import define, field

from griptape import utils
from griptape.engines.rag.stages import BaseRagStage

if TYPE_CHECKING:
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import (
BaseAfterResponseRagModule,
BaseBeforeResponseRagModule,
BaseRagModule,
BaseResponseRagModule,
)


@define(kw_only=True)
class ResponseRagStage(BaseRagStage):
before_response_modules: list[BaseBeforeResponseRagModule] = field(factory=list)
response_module: BaseResponseRagModule = field()
after_response_modules: list[BaseAfterResponseRagModule] = field(factory=list)
response_modules: list[BaseResponseRagModule] = field()

@property
def modules(self) -> list[BaseRagModule]:
ms = []

ms.extend(self.before_response_modules)
ms.extend(self.after_response_modules)

if self.response_module is not None:
ms.append(self.response_module)
ms.extend(self.response_modules)

return ms
Review comment on lines 24 to 29 (Member):
Do we still need this property?

Reply (Member, Author):
We do, because we test for module name uniqueness by using this property.
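
A sketch of the kind of name-uniqueness check the reply refers to. `find_duplicate_names` is a hypothetical helper, not the griptape implementation; it assumes only that each module exposes a `name` attribute, as `BaseRagModule` does.

```python
from collections import Counter
from types import SimpleNamespace


def find_duplicate_names(modules) -> list[str]:
    """Return the names used by more than one module, sorted (hypothetical helper)."""
    counts = Counter(m.name for m in modules)
    return sorted(name for name, n in counts.items() if n > 1)


# SimpleNamespace objects stand in for modules collected via a stage's `modules` property.
modules = [SimpleNamespace(name="A"), SimpleNamespace(name="B"), SimpleNamespace(name="A")]
print(find_duplicate_names(modules))
```

A check like this needs one flat list of every module in a stage, which is what the `modules` property provides.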


def run(self, context: RagContext) -> RagContext:
logging.info("GenerationStage: running %s before modules sequentially", len(self.before_response_modules))

for generator in self.before_response_modules:
context = generator.run(context)

logging.info("GenerationStage: running generation module")

context = self.response_module.run(context)
logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules))

logging.info("GenerationStage: running %s after modules sequentially", len(self.after_response_modules))
with self.futures_executor_fn() as executor:
results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.response_modules])

for generator in self.after_response_modules:
context = generator.run(context)
context.outputs = results

return context
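
The new `run` fans the same context out to every response module and stores the collected results. Below is a standalone sketch of that pattern with `concurrent.futures`. `run_in_parallel` is a hypothetical helper standing in for `utils.execute_futures_list`, whose exact behavior is not shown in this diff; this version returns results in submission order.

```python
from concurrent import futures


def run_in_parallel(callables, arg):
    """Submit every callable with the same argument; return results in submission order."""
    with futures.ThreadPoolExecutor() as executor:
        submitted = [executor.submit(fn, arg) for fn in callables]
        # Future.result() blocks until done and re-raises any exception from the worker.
        return [f.result() for f in submitted]


def shout(query: str) -> str:
    return query.upper()


def count_words(query: str) -> int:
    return len(query.split())


results = run_in_parallel([shout, count_words], "hello rag world")
print(results)
```

Since every module receives the same `context` object concurrently, modules that mutated it could race with one another. Having `run` return an artifact instead of writing to the context, as this PR does, sidesteps that.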