Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewrite rag to use dynamic config #162

Merged
merged 1 commit into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 54 additions & 20 deletions b2b/lambda/rag/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@
from aws_lambda_powertools.utilities.parser import parse, ValidationError
from aws_lambda_powertools.utilities.parser import BaseModel

from xayn_rag.rag import run_query, RagType
from xayn_rag.rag import run_query
from xayn_rag.context import (
ConfigContext,
Config,
LlmPlatformValues,
SearchPlatformValues,
EnvKey,
ConfigEnvLoader,
)
from xayn_rag.retrieval import SimpleSearchQuery


nlb_url = os.getenv("NLB_URL")
Expand All @@ -39,6 +36,9 @@

class QuestionRequest(BaseModel):
query: str
filter: str | None
include_properties: bool = True
use_hybrid_search: bool = False


@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST)
Expand All @@ -49,7 +49,12 @@ def lambda_handler(event: dict, context: LambdaContext) -> dict:
@app.post("/rag")
def get_answer():
tenant_id: Optional[str] = app.current_event.request_context.authorizer.principal_id
error_wrap = app.current_event.get_query_string_value(name='error_wrap', default_value='false') == 'true'
error_wrap = (
app.current_event.get_query_string_value(
name="error_wrap", default_value="false"
)
== "true"
)
if not tenant_id:
logger.error(
"request_context.authorizer.principal_id is not set! This lambda must be called with a valid TenantId!"
Expand Down Expand Up @@ -85,23 +90,52 @@ def get_answer():
error_wrap=error_wrap,
)

context = ConfigContext(
config={
Config.LLM_PLATFORM: LlmPlatformValues.HUGGINGFACE,
Config.SEARCH_PLATFORM: SearchPlatformValues.XAYN_INTERNAL,
# TODO move this to ssm
configs = {
"porschedemoe5": {
"type": "EM_GERMAN_RAG",
"config": {
"LLM_PLATFORM": "HUGGINGFACE",
"SEARCH_PLATFORM": "XAYN_INTERNAL",
},
"envs": {
"XAYN_SEARCH_ENDPOINT": nlb_url,
"TENANT_ID": tenant_id,
"HUGGINGFACE_ENDPOINT_TOKEN": llm_bearer_token,
"HUGGINGFACE_ENDPOINT_URL": llm_url,
"USE_TOP_N_RESULTS": 5,
},
},
env_loader=ConfigEnvLoader(
{
EnvKey.XAYN_SEARCH_ENDPOINT: nlb_url,
EnvKey.TENANT_ID: tenant_id,
EnvKey.HUGGINGFACE_ENDPOINT_TOKEN: llm_bearer_token,
EnvKey.HUGGINGFACE_ENDPOINT_URL: llm_url,
EnvKey.USE_TOP_N_RESULTS: 5,
}
),
"legaldemolarge": {
"type": "EM_GERMAN_RAG",
"config": {
"LLM_PLATFORM": "HUGGINGFACE",
"SEARCH_PLATFORM": "XAYN_INTERNAL",
},
"envs": {
"XAYN_SEARCH_ENDPOINT": nlb_url,
"TENANT_ID": tenant_id,
"HUGGINGFACE_ENDPOINT_TOKEN": llm_bearer_token,
"HUGGINGFACE_ENDPOINT_URL": llm_url,
"USE_TOP_N_RESULTS": 5,
},
},
}

config = configs[tenant_id]
context = ConfigContext(
config=config["config"],
env_loader=ConfigEnvLoader(config["envs"]),
)
res = run_query(
query=request.query, context=context, rag_type=RagType.EM_GERMAN_RAG
query=SimpleSearchQuery(
query=request.query,
filter_json=request.filter,
include_properties=request.include_properties,
use_hybrid_search=request.use_hybrid_search,
),
context=context,
rag_type=configs[tenant_id]["type"],
)
return convert_response(res, error_wrap=error_wrap)

Expand Down
2 changes: 1 addition & 1 deletion b2b/lambda/rag/src/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
xayn_rag @ git+ssh://[email protected]/xaynetwork/xayn_rag@rag_extensions
xayn_rag @ git+ssh://[email protected]/xaynetwork/xayn_rag@6f8838191b0ab2ff2f1e8c6f4bc1bfe00c5a6eb4
Loading