-
Notifications
You must be signed in to change notification settings - Fork 186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
`ResponseRagStage` and `PromptResponseRagModule` updates
#1056
Changes from 10 commits
044e870
8b97745
6f59f87
ebc6949
481f95a
6b602a9
d2f9c3b
ddc398c
7654f86
e45c2b5
8fd947a
d5a507f
f6e5201
4db3d30
8395cf8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,22 @@ | ||
from griptape.artifacts import ErrorArtifact | ||
from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver | ||
from griptape.engines.rag import RagContext, RagEngine | ||
from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule | ||
from griptape.engines.rag.stages import QueryRagStage, ResponseRagStage, RetrievalRagStage | ||
from griptape.loaders import WebLoader | ||
from griptape.rules import Rule, Ruleset | ||
|
||
prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0) | ||
|
||
vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) | ||
|
||
artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai") | ||
if isinstance(artifacts, ErrorArtifact): | ||
raise ValueError(artifacts.value) | ||
|
||
vector_store.upsert_text_artifacts({"griptape": artifacts}) | ||
vector_store.upsert_text_artifacts( | ||
{ | ||
"griptape": WebLoader(max_tokens=500).load("https://www.griptape.ai"), | ||
} | ||
) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this pass type checks? Edit: it probably will not; see #1057. |
||
rag_engine = RagEngine( | ||
query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="English")]), | ||
query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="english")]), | ||
retrieval_stage=RetrievalRagStage( | ||
max_chunks=5, | ||
retrieval_modules=[ | ||
|
@@ -25,12 +25,18 @@ | |
) | ||
], | ||
), | ||
response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=prompt_driver)), | ||
response_stage=ResponseRagStage( | ||
response_modules=[ | ||
PromptResponseRagModule( | ||
prompt_driver=prompt_driver, rulesets=[Ruleset(name="persona", rules=[Rule("Talk like a pirate")])] | ||
) | ||
] | ||
), | ||
) | ||
|
||
rag_context = RagContext( | ||
query="¿Qué ofrecen los servicios en la nube de Griptape?", | ||
module_configs={"MyAwesomeRetriever": {"query_params": {"namespace": "griptape"}}}, | ||
) | ||
|
||
print(rag_engine.process(rag_context).output.to_text()) | ||
print(rag_engine.process(rag_context).outputs[0].to_text()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
import uuid | ||
from abc import ABC | ||
from concurrent import futures | ||
from typing import TYPE_CHECKING, Any, Callable, Optional | ||
|
@@ -14,7 +15,9 @@ | |
|
||
@define(kw_only=True) | ||
class BaseRagModule(ABC): | ||
name: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) | ||
name: str = field( | ||
default=Factory(lambda self: f"{self.__class__.__name__}-{uuid.uuid4().hex}", takes_self=True), kw_only=True | ||
) | ||
Comment on lines +18 to +20
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What does name uniqueness get us? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It makes it much easier to add multiple modules of the same type without having to explicitly define names. Maybe that's something we add to tools as well? |
||
futures_executor_fn: Callable[[], futures.Executor] = field( | ||
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), | ||
) | ||
|
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Callable | ||
from typing import TYPE_CHECKING, Any, Callable, Optional | ||
|
||
from attrs import Factory, define, field | ||
|
||
|
@@ -9,19 +9,23 @@ | |
from griptape.utils import J2 | ||
|
||
if TYPE_CHECKING: | ||
from griptape.artifacts import BaseArtifact | ||
from griptape.drivers import BasePromptDriver | ||
from griptape.engines.rag import RagContext | ||
from griptape.rules import Ruleset | ||
|
||
|
||
@define(kw_only=True) | ||
class PromptResponseRagModule(BaseResponseRagModule): | ||
answer_token_offset: int = field(default=400) | ||
prompt_driver: BasePromptDriver = field() | ||
answer_token_offset: int = field(default=400) | ||
rulesets: list[Ruleset] = field(factory=list) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should implement |
||
metadata: Optional[str] = field(default=None) | ||
generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( | ||
default=Factory(lambda self: self.default_system_template_generator, takes_self=True), | ||
) | ||
|
||
def run(self, context: RagContext) -> RagContext: | ||
def run(self, context: RagContext) -> BaseArtifact: | ||
query = context.query | ||
tokenizer = self.prompt_driver.tokenizer | ||
included_chunks = [] | ||
|
@@ -45,15 +49,17 @@ def run(self, context: RagContext) -> RagContext: | |
output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact() | ||
|
||
if isinstance(output, TextArtifact): | ||
context.output = output | ||
return output | ||
else: | ||
raise ValueError("Prompt driver did not return a TextArtifact") | ||
|
||
return context | ||
|
||
def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: | ||
return J2("engines/rag/modules/response/prompt/system.j2").render( | ||
text_chunks=[c.to_text() for c in artifacts], | ||
before_system_prompt="\n\n".join(context.before_query), | ||
after_system_prompt="\n\n".join(context.after_query), | ||
) | ||
params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]} | ||
|
||
if len(self.rulesets) > 0: | ||
params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After implementing |
||
|
||
if self.metadata is not None: | ||
params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata) | ||
|
||
return J2("engines/rag/modules/response/prompt/system.j2").render(**params) |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,11 @@ | ||
from attrs import define | ||
|
||
from griptape.artifacts import ListArtifact | ||
from griptape.artifacts import BaseArtifact, ListArtifact | ||
from griptape.engines.rag import RagContext | ||
from griptape.engines.rag.modules import BaseResponseRagModule | ||
|
||
|
||
@define(kw_only=True) | ||
class TextChunksResponseRagModule(BaseResponseRagModule): | ||
def run(self, context: RagContext) -> RagContext: | ||
context.output = ListArtifact(context.text_chunks) | ||
|
||
return context | ||
def run(self, context: RagContext) -> BaseArtifact: | ||
return ListArtifact(context.text_chunks) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,49 +5,35 @@ | |
|
||
from attrs import define, field | ||
|
||
from griptape import utils | ||
from griptape.engines.rag.stages import BaseRagStage | ||
|
||
if TYPE_CHECKING: | ||
from griptape.engines.rag import RagContext | ||
from griptape.engines.rag.modules import ( | ||
BaseAfterResponseRagModule, | ||
BaseBeforeResponseRagModule, | ||
BaseRagModule, | ||
BaseResponseRagModule, | ||
) | ||
|
||
|
||
@define(kw_only=True) | ||
class ResponseRagStage(BaseRagStage): | ||
before_response_modules: list[BaseBeforeResponseRagModule] = field(factory=list) | ||
response_module: BaseResponseRagModule = field() | ||
after_response_modules: list[BaseAfterResponseRagModule] = field(factory=list) | ||
response_modules: list[BaseResponseRagModule] = field() | ||
|
||
@property | ||
def modules(self) -> list[BaseRagModule]: | ||
ms = [] | ||
|
||
ms.extend(self.before_response_modules) | ||
ms.extend(self.after_response_modules) | ||
|
||
if self.response_module is not None: | ||
ms.append(self.response_module) | ||
ms.extend(self.response_modules) | ||
|
||
return ms | ||
Comment on lines 24 to 29
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we still need this property? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We do, because we test for module name uniqueness by using this property. |
||
|
||
def run(self, context: RagContext) -> RagContext: | ||
logging.info("GenerationStage: running %s before modules sequentially", len(self.before_response_modules)) | ||
|
||
for generator in self.before_response_modules: | ||
context = generator.run(context) | ||
|
||
logging.info("GenerationStage: running generation module") | ||
|
||
context = self.response_module.run(context) | ||
logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules)) | ||
|
||
logging.info("GenerationStage: running %s after modules sequentially", len(self.after_response_modules)) | ||
with self.futures_executor_fn() as executor: | ||
results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.response_modules]) | ||
|
||
for generator in self.after_response_modules: | ||
context = generator.run(context) | ||
context.outputs = results | ||
|
||
return context |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing hyphens