diff --git a/wren-ai-service/src/pipelines/indexing/table_description.py b/wren-ai-service/src/pipelines/indexing/table_description.py index 2255bd66c..eb5a23331 100644 --- a/wren-ai-service/src/pipelines/indexing/table_description.py +++ b/wren-ai-service/src/pipelines/indexing/table_description.py @@ -31,6 +31,7 @@ def _additional_meta() -> Dict[str, Any]: "id": str(uuid.uuid4()), "meta": { "type": "TABLE_DESCRIPTION", + "name": chunk["name"], **_additional_meta(), }, "content": str(chunk), diff --git a/wren-ai-service/src/pipelines/retrieval/retrieval.py b/wren-ai-service/src/pipelines/retrieval/retrieval.py index 584fffc3f..003f1bc54 100644 --- a/wren-ai-service/src/pipelines/retrieval/retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/retrieval.py @@ -224,20 +224,36 @@ def check_using_db_schemas_without_pruning( for table_schema in construct_db_schemas: if table_schema["type"] == "TABLE": retrieval_results.append( - build_table_ddl( - table_schema, - ) + { + "table_name": table_schema["name"], + "table_ddl": build_table_ddl( + table_schema, + ), + } ) for document in dbschema_retrieval: content = ast.literal_eval(document.content) if content["type"] == "METRIC": - retrieval_results.append(_build_metric_ddl(content)) + retrieval_results.append( + { + "table_name": content["name"], + "table_ddl": _build_metric_ddl(content), + } + ) elif content["type"] == "VIEW": - retrieval_results.append(_build_view_ddl(content)) + retrieval_results.append( + { + "table_name": content["name"], + "table_ddl": _build_view_ddl(content), + } + ) - _token_count = len(encoding.encode(" ".join(retrieval_results))) + table_ddls = [ + retrieval_result["table_ddl"] for retrieval_result in retrieval_results + ] + _token_count = len(encoding.encode(" ".join(table_ddls))) if _token_count > 100_000 or not allow_using_db_schemas_without_pruning: return { "db_schemas": [], @@ -296,7 +312,7 @@ def construct_retrieval_results( filter_columns_in_tables: dict, construct_db_schemas: list[dict], dbschema_retrieval: list[Document], -) -> list[str]: +) -> list[dict]: if filter_columns_in_tables: columns_and_tables_needed = orjson.loads( filter_columns_in_tables["replies"][0] @@ -314,13 +330,18 @@ def construct_retrieval_results( for table_schema in construct_db_schemas: if table_schema["type"] == "TABLE" and table_schema["name"] in tables: retrieval_results.append( - build_table_ddl( - table_schema, - columns=set( - columns_and_tables_needed[table_schema["name"]]["columns"] + { + "table_name": table_schema["name"], + "table_ddl": build_table_ddl( + table_schema, + columns=set( + columns_and_tables_needed[table_schema["name"]][ + "columns" + ] + ), + tables=tables, ), - tables=tables, - ) + } ) for document in dbschema_retrieval: @@ -328,9 +349,19 @@ def construct_retrieval_results( content = ast.literal_eval(document.content) if content["type"] == "METRIC": - retrieval_results.append(_build_metric_ddl(content)) + retrieval_results.append( + { + "table_name": content["name"], + "table_ddl": _build_metric_ddl(content), + } + ) elif content["type"] == "VIEW": - retrieval_results.append(_build_view_ddl(content)) + retrieval_results.append( + { + "table_name": content["name"], + "table_ddl": _build_view_ddl(content), + } + ) else: retrieval_results = check_using_db_schemas_without_pruning["db_schemas"] diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 6bf26e7fa..db5e9f10d 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -92,6 +92,7 @@ class AskResultResponse(BaseModel): rephrased_question: Optional[str] = None intent_reasoning: Optional[str] = None type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None + retrieval_response: Optional[List[str]] = None response: Optional[List[AskResult]] = None error: Optional[AskError] = None @@ -243,6 +244,8 @@ async def ask( id=ask_request.project_id, ) documents = retrieval_result.get("construct_retrieval_results", []) + table_names = [document.get("table_name") for document in documents] + table_ddls = [document.get("table_ddl") for document in documents] if not documents: logger.exception(f"ask pipeline - NO_RELEVANT_DATA: {user_query}") @@ -267,6 +270,7 @@ async def ask( type="TEXT_TO_SQL", rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + retrieval_response=table_names, ) sql_samples = ( @@ -322,7 +326,7 @@ async def ask( sql_correction_results = await self._pipelines[ "sql_correction" ].run( - contexts=documents, + contexts=table_ddls, invalid_generation_results=failed_dry_run_results, project_id=ask_request.project_id, ) @@ -348,6 +352,7 @@ async def ask( response=api_results, rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + retrieval_response=table_names, ) results["ask_result"] = api_results results["metadata"]["type"] = "TEXT_TO_SQL" @@ -363,6 +368,7 @@ async def ask( ), rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, + retrieval_response=table_names, ) results["metadata"]["error_type"] = "NO_RELEVANT_SQL" results["metadata"]["type"] = "TEXT_TO_SQL" diff --git a/wren-ai-service/src/web/v1/services/question_recommendation.py b/wren-ai-service/src/web/v1/services/question_recommendation.py index 3af7f603c..1284af3d2 100644 --- a/wren-ai-service/src/web/v1/services/question_recommendation.py +++ b/wren-ai-service/src/web/v1/services/question_recommendation.py @@ -74,9 +74,10 @@ async def _validate_question( id=project_id, ) documents = retrieval_result.get("construct_retrieval_results", []) + table_ddls = [document.get("table_ddl") for document in documents] generated_sql = await self._pipelines["sql_generation"].run( query=candidate["question"], - contexts=documents, + contexts=table_ddls, exclude=[], configuration=Configuration(), project_id=project_id, diff --git a/wren-ai-service/tests/pytest/pipelines/indexing/test_table_description.py b/wren-ai-service/tests/pytest/pipelines/indexing/test_table_description.py index fbe1cdb46..bc7328c1a 100644 --- a/wren-ai-service/tests/pytest/pipelines/indexing/test_table_description.py +++ b/wren-ai-service/tests/pytest/pipelines/indexing/test_table_description.py @@ -37,7 +37,7 @@ def test_single_table_description(): assert len(actual["documents"]) == 1 document: Document = actual["documents"][0] - assert document.meta == {"type": "TABLE_DESCRIPTION"} + assert document.meta == {"type": "TABLE_DESCRIPTION", "name": "user"} assert document.content == str( { "name": "user", @@ -71,6 +71,7 @@ def test_multiple_table_descriptions(): document_1: Document = actual["documents"][0] assert document_1.meta == { "type": "TABLE_DESCRIPTION", + "name": "user", } assert document_1.content == str( { @@ -81,7 +82,7 @@ def test_multiple_table_descriptions(): ) document_2: Document = actual["documents"][1] - assert document_2.meta == {"type": "TABLE_DESCRIPTION"} + assert document_2.meta == {"type": "TABLE_DESCRIPTION", "name": "order"} assert document_2.content == str( { "name": "order", @@ -121,7 +122,7 @@ def test_table_description_missing_description(): assert len(actual["documents"]) == 1 document: Document = actual["documents"][0] - assert document.meta == {"type": "TABLE_DESCRIPTION"} + assert document.meta == {"type": "TABLE_DESCRIPTION", "name": "user"} assert document.content == str( {"name": "user", "mdl_type": "MODEL", "description": ""} )