diff --git a/app_utils/chat.py b/app_utils/chat.py
index 1addce99..bad1dd45 100644
--- a/app_utils/chat.py
+++ b/app_utils/chat.py
@@ -9,7 +9,7 @@
 API_ENDPOINT = "https://{HOST}/api/v2/cortex/analyst/message"


-@st.cache_data(ttl=300, show_spinner=False)
+@st.cache_data(ttl=60, show_spinner=False)
 def send_message(
     _conn: SnowflakeConnection, semantic_model: str, messages: list[dict[str, str]]
 ) -> Dict[str, Any]:
diff --git a/app_utils/shared_utils.py b/app_utils/shared_utils.py
index b59fe384..249829d3 100644
--- a/app_utils/shared_utils.py
+++ b/app_utils/shared_utils.py
@@ -40,7 +40,6 @@
     fetch_tables_views_in_schema,
     fetch_warehouses,
     fetch_yaml_names_in_stage,
-    fetch_columns_names_in_table,
 )

 SNOWFLAKE_ACCOUNT = os.environ.get("SNOWFLAKE_ACCOUNT_LOCATOR", "")
@@ -126,7 +125,7 @@ def get_snowflake_connection() -> SnowflakeConnection:
     return get_connector().open_connection(db_name="")


-# @st.cache_resource(show_spinner=False)
+@st.cache_resource(show_spinner=False)
 def set_snowpark_session(_conn: Optional[SnowflakeConnection] = None) -> None:
     """
     Creates a snowpark for python session.
@@ -201,130 +200,6 @@ def get_available_stages(schema: str) -> List[str]:
     return fetch_stages_in_schema(get_snowflake_connection(), schema)


-@st.cache_resource(show_spinner=False)
-def validate_table_schema(table: str, schema: Dict[str, str]) -> bool:
-    table_schema = fetch_table_schema(get_snowflake_connection(), table)
-    if set(schema) != set(table_schema):
-        return False
-    for col_name, col_type in table_schema.items():
-        if not (schema[col_name] in col_type):
-            return False
-    return True
-
-
-@st.cache_resource(show_spinner=False)
-def validate_table_exist(schema: str, table_name:str) -> bool:
-    """
-    Validate table exist in the Snowflake account.
-
-    Returns:
-        List[str]: A list of available stages.
-    """
-    table_names = fetch_tables_views_in_schema(get_snowflake_connection(), schema)
-    table_names = [table.split(".")[2] for table in table_names]
-    if table_name.upper() in table_names:
-        return True
-    return False
-
-
-def schema_selector_container(
-    db_selector: Dict[str, str], schema_selector: Dict[str, str]
-) -> List[str]:
-    """
-    Common component that encapsulates db/schema/table selection for the admin app.
-    When a db/schema/table is selected, it is saved to the session state for reading elsewhere.
-    Returns: None
-    """
-    available_schemas = []
-    available_tables = []
-
-    # First, retrieve all databases that the user has access to.
-    eval_database = st.selectbox(
-        db_selector["label"],
-        options=get_available_databases(),
-        index=None,
-        key=db_selector["key"],
-    )
-    if eval_database:
-        # When a valid database is selected, fetch the available schemas in that database.
-        try:
-            available_schemas = get_available_schemas(eval_database)
-        except (ValueError, ProgrammingError):
-            st.error("Insufficient permissions to read from the selected database.")
-            st.stop()
-
-    eval_schema = st.selectbox(
-        schema_selector["label"],
-        options=available_schemas,
-        index=None,
-        key=schema_selector["key"],
-        format_func=lambda x: format_snowflake_context(x, -1),
-    )
-    if eval_schema:
-        # When a valid schema is selected, fetch the available tables in that schema.
-        try:
-            available_tables = get_available_tables(eval_schema)
-        except (ValueError, ProgrammingError):
-            st.error("Insufficient permissions to read from the selected schema.")
-            st.stop()
-
-    return available_tables
-
-
-def table_selector_container(
-    db_selector: Dict[str, str],
-    schema_selector: Dict[str, str],
-    table_selector: Dict[str, str],
-) -> Optional[str]:
-    """
-    Common component that encapsulates db/schema/table selection for the admin app.
-    When a db/schema/table is selected, it is saved to the session state for reading elsewhere.
-    Returns: None
-    """
-    available_schemas = []
-    available_tables = []
-
-    # First, retrieve all databases that the user has access to.
-    eval_database = st.selectbox(
-        db_selector["label"],
-        options=get_available_databases(),
-        index=None,
-        key=db_selector["key"],
-    )
-    if eval_database:
-        # When a valid database is selected, fetch the available schemas in that database.
-        try:
-            available_schemas = get_available_schemas(eval_database)
-        except (ValueError, ProgrammingError):
-            st.error("Insufficient permissions to read from the selected database.")
-            st.stop()
-
-    eval_schema = st.selectbox(
-        schema_selector["label"],
-        options=available_schemas,
-        index=None,
-        key=schema_selector["key"],
-        format_func=lambda x: format_snowflake_context(x, -1),
-    )
-    if eval_schema:
-        # When a valid schema is selected, fetch the available tables in that schema.
-        try:
-            available_tables = get_available_tables(eval_schema)
-        except (ValueError, ProgrammingError):
-            st.error("Insufficient permissions to read from the selected schema.")
-            st.stop()
-
-    tables = st.selectbox(
-        table_selector["label"],
-        options=available_tables,
-        index=None,
-        key=table_selector["key"],
-        format_func=lambda x: format_snowflake_context(x, -1),
-    )
-
-    return tables
-
-
 @st.cache_resource(show_spinner=False)
 def validate_table_schema(table: str, schema: Dict[str, str]) -> bool:
     table_schema = fetch_table_schema(get_snowflake_connection(), table)
diff --git a/journeys/builder.py b/journeys/builder.py
index 453a897a..c7fe7bf2 100644
--- a/journeys/builder.py
+++ b/journeys/builder.py
@@ -42,12 +42,7 @@ class CortexSearchConfig:
     warehouse_name: str
     target_lag: str

-# if not st.session_state["semantic_model_name"]:
-#     st.error("Please provide a name for your semantic model.")
-# elif not st.session_state["selected_tables"]:
-#     st.error("Please select at least one table to proceed.")
-# else:
-#     st.session_state["table_selector_submitted"] = True
+
 def init_session_state() -> None:
     default_state = {
         "build_semantic_model": False,
diff --git a/journeys/iteration.py b/journeys/iteration.py
index 67312a28..ac9d73c7 100644
--- a/journeys/iteration.py
+++ b/journeys/iteration.py
@@ -7,11 +7,9 @@
 import streamlit as st
 from snowflake.connector import ProgrammingError, SnowflakeConnection
 from streamlit import config
-from snowflake.connector.pandas_tools import write_pandas
 from streamlit.delta_generator import DeltaGenerator
 from streamlit_extras.row import row
 from streamlit_extras.stylable_container import stylable_container
-from semantic_model_generator.snowflake_utils.snowflake_connector import fetch_table

 from app_utils.chat import send_message
 from app_utils.shared_utils import (
@@ -42,46 +40,6 @@
 from semantic_model_generator.protos import semantic_model_pb2
 from semantic_model_generator.validate_model import validate

-EVALUATION_TABLE_SCHEMA = {
-    "ID": "VARCHAR",
-    "QUERY": "VARCHAR",
-    "GOLD_SQL": "VARCHAR",
-}
-RESULTS_TABLE_SCHEMA = {
-    "TIMESTAMP": "DATETIME",
-    "ID": "VARCHAR",
"VARCHAR", - "QUERY": "VARCHAR", - "ANALYST_TEXT": "VARCHAR", - "ANALYST_SQL": "VARCHAR", - "ANALYST_RESULT": "VARCHAR", - "GOLD_SQL": "VARCHAR", - "GOLD_RESULT": "VARCHAR", - "CORRECT": "BOOLEAN", - "EXPLANATION": "VARCHAR", - "MODEL_HASH": "VARCHAR", -} - -LLM_JUDGE_PROMPT_TEMPLATE = """\ -[INST] Your task is to determine whether the two given dataframes are -equivalent semantically in the context of a question. You should attempt to -answer the given question by using the data in each dataframe. If the two -answers are equivalent, those two dataframes are considered equivalent. -Otherwise, they are not equivalent. Please also provide your reasoning. -If they are equivalent, output "REASON: . ANSWER: true". If they are -not equivalent, output "REASON: . ANSWER: false". - -### QUESTION: {input_question} - -* DATAFRAME 1: -{frame1_str} - -* DATAFRAME 2: -{frame2_str} - -Are the two dataframes equivalent? -OUTPUT: -[/INST] """ - # Set minCachedMessageSize to 500 MB to disable forward message cache: # st.set_config would trigger an error, only the set_config from config module works config.set_option("global.minCachedMessageSize", 500 * 1e6) @@ -409,205 +367,6 @@ def chat_and_edit_vqr(_conn: SnowflakeConnection) -> None: st.session_state.active_suggestion = None -def clear_evaluation_data() -> None: - session_states = ( - "eval_table_frame", - "eval_table_hash", - "selected_eval_database", - "selected_eval_schema", - "selected_eval_table", - "selected_results_eval_database", - "selected_results_eval_new_table", - "selected_results_eval_new_table_no_schema", - "selected_results_eval_old_table", - "selected_results_eval_schema", - "use_existing_table", - "eval_timestamp", - ) - for feature in session_states: - if feature in st.session_state: - del st.session_state[feature] - - -def validate_table_columns(param, evaluation_table_columns): - pass - - -@st.experimental_dialog("Evaluation Data", width="large") -def evaluation_data_dialog() -> None: - evaluation_table_columns = ["ID", "QUERY", "GOLD_SQL"] - st.markdown("Please select evaluation table") - table_selector_container( - db_selector={"key": "selected_eval_database", "label": "Eval database"}, - schema_selector={"key": "selected_eval_schema", "label": "Eval schema"}, - table_selector={"key": "selected_eval_table", "label": "Eval table"}, - ) - if st.button("Use Table"): - if ( - not st.session_state["selected_eval_database"] - or not st.session_state["selected_eval_schema"] - or not st.session_state["selected_eval_table"] - ): - st.error("Please fill in all fields.") - return - - if not validate_table_columns(st.session_state["selected_eval_table"], evaluation_table_columns): - st.error("Table must have columns {evaluation_table_columns} to be used in Evaluation") - return - - st.session_state["eval_table"] = SnowflakeTable( - table_database=st.session_state["selected_eval_database"], - table_schema=st.session_state["selected_eval_schema"], - table_name=st.session_state["selected_eval_table"], - ) - st.session_state["eval_table_hash"] = get_table_hash( - conn=get_snowflake_connection(), table_fqn=st.session_state.eval_table.table_name - ) - eval_table_frame = fetch_table( - conn=get_snowflake_connection(), table_fqn=st.session_state.eval_table.table_name - ) - st.session_state["eval_table_frame"] = eval_table_frame.set_index("ID") - - st.rerun() - - - if not eval_results_existing_table: - schema_selector_container( - db_selector={"key": "selected_results_eval_database","label":"Results database"}, - schema_selector={"key": 
"selected_results_eval_schema","label":"Results schema"},) - - original_new_table_name = st.text_input( - key="selected_results_eval_new_table_no_schema", - label="Enter the table name to upload evaluation results", - ) - if st.button("Create Table"): - if ( - not st.session_state["selected_results_eval_database"] - or not st.session_state["selected_results_eval_schema"] - or not new_table_name - ): - st.error("Please fill in all fields.") - return - - if ( - st.session_state["selected_results_eval_database"] - and st.session_state["selected_results_eval_schema"] - and validate_table_exist( - st.session_state["selected_results_eval_schema"], new_table_name - ) - ): - st.error("Table already exists") - return - - - with st.spinner("Creating table..."): - success = create_table_in_schema( - conn=get_snowflake_connection(), - schema_name=st.session_state["selected_results_eval_schema"], - table_name=new_table_name, - columns_schema=[ - f"{k} {v}" for k, v in results_table_columns.items() - ], - ) - if success: - st.success(f"Table {new_table_name} created successfully!") - else: - st.error(f"Failed to create table {new_table_name}") - return - - fqn_table_name = ".".join([st.session_state["selected_results_eval_schema"],new_table_name.upper()]) - - st.session_state["eval_results_table"] = SnowflakeTable( - table_database=st.session_state["selected_results_eval_database"], - table_schema=st.session_state["selected_results_eval_schema"], - table_name=fqn_table_name, - ) - - st.rerun() - - else: - table_selector_container( - db_selector={ - "key": "selected_results_eval_database", - "label": "Results database", - }, - schema_selector={ - "key": "selected_results_eval_schema", - "label": "Results schema", - }, - table_selector={ - "key": "selected_results_eval_old_table", - "label": "Results table", - }, - ) - - st.divider() - - if st.button("Use Tables"): - st.session_state["selected_results_eval_table"] = st.session_state.get( - "selected_results_eval_new_table" - ) or st.session_state.get("selected_results_eval_old_table") - - if ( - not st.session_state["selected_eval_database"] - or not st.session_state["selected_eval_schema"] - or not st.session_state["selected_eval_table"] - or not st.session_state["selected_results_eval_database"] - or not st.session_state["selected_results_eval_schema"] - or not st.session_state["selected_results_eval_table"] - ): - st.error("Please fill in all fields.") - return - - if not validate_table_schema( - table=st.session_state["selected_eval_table"], - schema=EVALUATION_TABLE_SCHEMA, - ): - st.error(f"Evaluation table must have schema {EVALUATION_TABLE_SCHEMA}.") - return - - if eval_results_existing_table: - if not validate_table_schema( - table=st.session_state["selected_results_eval_old_table"], - schema=RESULTS_TABLE_SCHEMA, - ): - st.error( - f"Evaluation result table must have schema {RESULTS_TABLE_SCHEMA}." - ) - return - - if not validate_table_columns(st.session_state["selected_results_eval_table"], tuple(results_table_columns.keys())): - st.error(f"Table must have columns {list(results_table_columns.keys())}.") - return - - with st.spinner("Creating table..."): - success = create_table_in_schema( - conn=get_snowflake_connection(), - table_fqn=st.session_state["selected_results_eval_new_table"], - columns_schema=RESULTS_TABLE_SCHEMA, - ) - if success: - st.success( - f'Table {st.session_state["selected_results_eval_new_table"]} created successfully!' 
- ) - else: - st.error( - f'Failed to create table {st.session_state["selected_results_eval_new_table"]}' - ) - return - - st.session_state["eval_table_hash"] = get_table_hash( - conn=get_snowflake_connection(), - table_fqn=st.session_state["selected_eval_table"], - ) - st.session_state["eval_table_frame"] = fetch_table( - conn=get_snowflake_connection(), - table_fqn=st.session_state["selected_eval_table"], - ).set_index("ID") - - st.rerun() - - @st.experimental_dialog("Upload", width="small") def upload_dialog(content: str) -> None: def upload_handler(file_name: str) -> None: @@ -846,19 +605,6 @@ def set_up_requirements() -> None: help="Checking this box will enable you to add/edit join paths in your semantic model. If enabling this setting, please ensure that you have the proper parameters set on your Snowflake account. Reach out to your account team for access.", ) - # # TODOTZ - uncomment this block to use defaults for testing - # print("USING DEFAULTS FOR TESTING") - # st.session_state["snowflake_stage"] = SnowflakeStage( - # stage_database="TZAYATS", - # stage_schema="TZAYATS.TESTING", - # stage_name="TZAYATS.TESTING.MY_SEMANTIC_MODELS", - # ) - # st.session_state["file_name"] = "revenue_timeseries_update.yaml" - # st.session_state["page"] = GeneratorAppScreen.ITERATION - # st.session_state["experimental_features"] = experimental_features - # st.rerun() - - # TODOTZ - comment this block to use defaults for testing if st.button( "Submit", disabled=not st.session_state["selected_iteration_database"] @@ -917,387 +663,6 @@ def chat_settings_dialog() -> None: Note that the Cortex Analyst semantic model must be validated before integrating partner semantics.""" -def evaluation_mode_show() -> None: - if st.button("Set Evaluation Tables", on_click=clear_evaluation_data): - evaluation_data_dialog() - - if "validated" in st.session_state and not st.session_state["validated"]: - st.error("Please validate your semantic model before evaluating.") - return - - # TODO: find a less awkward way of specifying this. 
-    if any(
-        key not in st.session_state
-        for key in ("selected_eval_table", "eval_table_hash", "eval_table_frame")
-    ):
-        st.error("Please set evaluation tables.")
-        return
-
-    else:
-        results_table = st.session_state.get(
-            "selected_results_eval_old_table"
-        ) or st.session_state.get("selected_results_eval_new_table")
-        st.session_state["eval_timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")
-        summary_stats = pd.DataFrame(
-            [
-                ["Evaluation Timestamp", st.session_state["eval_timestamp"]],
-                ["Evaluation Table", st.session_state["selected_eval_table"]],
-                ["Evaluation Result Table", results_table],
-                ["Evaluation Table Hash", st.session_state["eval_table_hash"]],
-                ["Semantic Model YAML Hash", hash(st.session_state["working_yml"])],
-                ["Query Count", len(st.session_state["eval_table_frame"])],
-            ],
-            columns=["Summary Statistic", "Value"],
-        )
-        st.dataframe(summary_stats, hide_index=True)
-
-        send_analyst_requests()
-        run_sql_queries()
-        result_comparisons()
-
-
-def send_analyst_requests() -> None:
-    def _get_content(
-        x: dict, item_type: str, key: str, default: str = ""  # type: ignore[type-arg]
-    ) -> str:
-        result = next(
-            (
-                item[key]
-                for item in x["message"]["content"]
-                if item["type"] == item_type
-            ),
-            default,
-        )
-        return result
-
-    eval_table_frame: pd.DataFrame = st.session_state["eval_table_frame"]
-
-    total_requests = len(eval_table_frame)
-    progress_bar = st.progress(0)
-    status_text = st.empty()
-    start_time = time.time()
-    analyst_results = []
-
-    for i, (row_id, row_id) in enumerate(eval_table_frame.iterrows(), start=1):
-        status_text.text(f"Sending request {i}/{total_requests} to Analyst...")
-        messages = [
-            {"role": "user", "content": [{"type": "text", "text": row_id["QUERY"]}]}
-        ]
-        semantic_model = proto_to_yaml(st.session_state.semantic_model)
-        try:
-            response = send_message(
-                _conn=get_snowflake_connection(),
-                semantic_model=semantic_model,
-                messages=messages,  # type: ignore[arg-type]
-            )
-            response_text = _get_content(response, item_type="text", key="text")
-            response_sql = _get_content(response, item_type="sql", key="statement")
-            analyst_results.append(
-                dict(ID=row_id, ANALYST_TEXT=response_text, ANALYST_SQL=response_sql)
-            )
-        except Exception as e:
-            import traceback
-
-            st.error(f"Problem with {row_id}: {e} \n{traceback.format_exc()}")
-
-        progress_bar.progress(i / total_requests)
-        time.sleep(0.1)
-
-    elapsed_time = time.time() - start_time
-    status_text.text(
-        f"All analyst requests received ✅ (Time taken: {elapsed_time:.2f} seconds)"
-    )
-
-    analyst_results_frame = pd.DataFrame(analyst_results).set_index("ID")
-    st.session_state["analyst_results_frame"] = analyst_results_frame
-
-
-def run_sql_queries() -> None:
-    eval_table_frame: pd.DataFrame = st.session_state["eval_table_frame"]
-    analyst_results_frame = st.session_state["analyst_results_frame"]
-
-    total_requests = len(eval_table_frame)
-    progress_bar = st.progress(0)
-    status_text = st.empty()
-    start_time = time.time()
-
-    analyst_results = []
-    gold_results = []
-
-    for i, (row_id, eval_row) in enumerate(eval_table_frame.iterrows(), start=1):
-        status_text.text(f"Evaluating Analyst query {i}/{total_requests}...")
-
-        analyst_query = analyst_results_frame.loc[row_id, "ANALYST_SQL"]
-        analyst_result = execute_query(
-            conn=get_snowflake_connection(), query=analyst_query
-        )
-        analyst_results.append(analyst_result)
-
-        gold_query = eval_table_frame.loc[row_id, "GOLD_SQL"]
-        gold_result = execute_query(conn=get_snowflake_connection(), query=gold_query)
-        gold_results.append(gold_result)
-
-        progress_bar.progress(i / total_requests)
-        time.sleep(0.1)
-
-    st.session_state["query_results_frame"] = pd.DataFrame(
-        data=dict(ANALYST_RESULT=analyst_results, GOLD_RESULT=gold_results),
-        index=eval_table_frame.index,
-    )
-
-    elapsed_time = time.time() - start_time
-    status_text.text(
-        f"All analyst and gold queries run ✅ (Time taken: {elapsed_time:.2f} seconds)"
-    )
-
-
-def _match_series(analyst_frame: pd.DataFrame, gold_series: pd.Series) -> str | None:
-    """Determine which result frame column name matches the gold series.
-
-    Args:
-        analyst_frame: the data generated from the LLM constructed user query
-        gold_series: a column from the data generated from the gold sql
-
-    Returns:
-        if there is a match, the results column name, if not, None
-    """
-    for analyst_col in analyst_frame:
-        assert isinstance(analyst_col, str)
-        try:
-            pd.testing.assert_series_equal(
-                left=analyst_frame[analyst_col],
-                right=gold_series,
-                check_names=False,
-            )
-            return analyst_col
-        except AssertionError:
-            pass
-
-    return None
-
-
-def _results_contain_gold_data(
-    analyst_frame: pd.DataFrame,
-    gold_frame: pd.DataFrame,
-) -> bool:
-    """Determine if result frame contains all the same values as a gold frame.
-
-    Args:
-        analyst_frame: the data generated from the LLM constructed user query
-        gold_frame: the data generated from a gold sql query
-
-    Returns:
-        a boolean indicating if the results contain the gold data
-    """
-    if analyst_frame.shape[0] != gold_frame.shape[0]:
-        return False
-
-    unmatched_result_cols = analyst_frame.columns
-    for gold_col in gold_frame:
-        matching_col = _match_series(
-            analyst_frame=analyst_frame[unmatched_result_cols],
-            gold_series=gold_frame[gold_col],
-        )
-        if matching_col is None:
-            return False
-        else:
-            unmatched_result_cols = unmatched_result_cols.drop(matching_col)
-
-    return True
-
-
-def _llm_judge(frame: pd.DataFrame) -> pd.DataFrame:
-
-    if frame.empty:
-        return pd.DataFrame({"EXPLANATION": [], "CORRECT": []})
-
-    # create prompt frame series
-    table_name = "__LLM_JUDGE_TEMP_TABLE"
-    col_name = "LLM_JUDGE_PROMPT"
-
-    prompt_frame = frame.apply(
-        axis=1,
-        func=lambda x: LLM_JUDGE_PROMPT_TEMPLATE.format(
-            input_question=x["QUERY"],
-            frame1_str=x["ANALYST_RESULT"].to_string(index=False),
-            frame2_str=x["GOLD_RESULT"].to_string(index=False),
-        ),
-    ).to_frame(name=col_name)
-    conn = get_snowflake_connection()
-    _ = write_pandas(
-        conn=conn,
-        df=prompt_frame,
-        table_name=table_name,
-        auto_create_table=True,
-        table_type="temporary",
-        overwrite=True,
-    )
-
-    query = f"""
-        SELECT SNOWFLAKE.CORTEX.COMPLETE('mistral-large2', {col_name}) AS LLM_JUDGE
-        FROM {conn.database}.{conn.schema}.{table_name}
-    """
-    cursor = conn.cursor()
-    cursor.execute(query)
-    llm_judge_frame = cursor.fetch_pandas_all()
-    llm_judge_frame.index = frame.index
-
-    reason_filter = re.compile(r"REASON\:([\S\s]*?)ANSWER\:")
-    answer_filter = re.compile(r"ANSWER\:([\S\s]*?)$")
-
-    def _safe_re_search(x, filter):  # type: ignore[no-untyped-def]
-        try:
-            return re.search(filter, x).group(1).strip()  # type: ignore[union-attr]
-        except Exception as e:
-            return f"Could Not Parse LLM Judge Response: {x} with error: {e}"
-
-    llm_judge_frame["EXPLANATION"] = llm_judge_frame["LLM_JUDGE"].apply(
-        _safe_re_search, args=(reason_filter,)
-    )
-    llm_judge_frame["CORRECT"] = (
-        llm_judge_frame["LLM_JUDGE"]
-        .apply(_safe_re_search, args=(answer_filter,))
-        .str.lower()
-        .eq("true")
-    )
-    return llm_judge_frame
-
-
-def visualize_eval_results(frame: pd.DataFrame) -> None:
-    n_questions = len(frame)
-    n_correct = frame["CORRECT"].sum()
-    accuracy = (n_correct / n_questions) * 100
-    st.markdown(
-        f"###### Results: {n_correct} out of {n_questions} questions correct with accuracy {accuracy:.2f}%"
-    )
-    for id, frame_row in frame.iterrows():
-        match_emoji = "✅" if row["CORRECT"] else "❌"
-        with st.expander(f"Row ID: {id} {match_emoji}"):
-            st.write(f"Input Query: {frame_row['QUERY']}")
-            st.write(frame_row["ANALYST_TEXT"].replace("\n", " "))
-
-            col1, col2 = st.columns(2)
-
-            with col1:
-                st.write("Analyst SQL")
-                st.code(frame_row["ANALYST_SQL"], language="sql")
-
-            with col2:
-                st.write("Golden SQL")
-                st.code(frame_row["GOLD_SQL"], language="sql")
-
-            col1, col2 = st.columns(2)
-            with col1:
-                if isinstance(frame_row["ANALYST_RESULT"], str):
-                    st.error(frame_row["ANALYST_RESULT"])
-                else:
-                    st.write(frame_row["ANALYST_RESULT"])
-
-            with col2:
-                if isinstance(frame_row["GOLD_RESULT"], str):
-                    st.error(frame_row["GOLD_RESULT"])
-                else:
-                    st.write(frame_row["GOLD_RESULT"])
-
-            st.write(f"**Explanation**: {frame_row['EXPLANATION']}")
-
-
-def result_comparisons() -> None:
-    eval_table_frame: pd.DataFrame = st.session_state["eval_table_frame"]
-    analyst_results_frame = st.session_state["analyst_results_frame"]
-    query_results_frame = st.session_state["query_results_frame"]
-
-    frame = pd.concat(
-        [eval_table_frame, analyst_results_frame, query_results_frame], axis=1
-    )
-    start_time = time.time()
-    status_text = st.empty()
-
-    matches = pd.Series(False, index=frame.index)
-    explanations = pd.Series("", index=frame.index)
-    use_llm_judge = ""
-
-    status_text.text("Checking for exact matches...")
-    for row_id, res_row in frame.iterrows():
-        analyst_is_frame = isinstance(res_row["ANALYST_RESULT"], pd.DataFrame)
-        gold_is_frame = isinstance(res_row["GOLD_RESULT"], pd.DataFrame)
-        if (not analyst_is_frame) and (not gold_is_frame):
-            matches[row_id] = False
-            explanations[row_id] = dedent(
-                f"""
-                analyst sql had an error: {res_row["ANALYST_RESULT"]}
-                gold sql had an error: {res_row["GOLD_RESULT"]}
-                """
-            )
-        elif (not analyst_is_frame) and gold_is_frame:
-            matches[row_id] = False
-            explanations[row_id] = dedent(
-                f"""
-                analyst sql had an error: {res_row["ANALYST_RESULT"]}
-                """
-            )
-        elif analyst_is_frame and (not gold_is_frame):
-            matches[row_id] = False
-            explanations[row_id] = dedent(
-                f"""
-                gold sql had an error: {res_row["GOLD_RESULT"]}
-                """
-            )
-        else:
-            exact_match = _results_contain_gold_data(
-                analyst_frame=res_row["ANALYST_RESULT"],
-                gold_frame=res_row["GOLD_RESULT"],
-            )
-            matches[row_id] = exact_match
-            explanations[row_id] = (
-                "Data matches exactly" if exact_match else use_llm_judge
-            )
-
-    frame["CORRECT"] = matches
-    frame["EXPLANATION"] = explanations
-
-    filtered_frame = frame[explanations == use_llm_judge]
-
-    status_text.text("Calling LLM Judge...")
-    llm_judge_frame = _llm_judge(frame=filtered_frame)
-
-    for col in ("CORRECT", "EXPLANATION"):
-        frame[col] = llm_judge_frame[col].combine_first(frame[col])
-
-    elapsed_time = time.time() - start_time
-    status_text.text(
-        f"Analyst and Gold Results Compared ✅ (Time taken: {elapsed_time:.2f} seconds)"
-    )
-
-    visualize_eval_results(frame)
-
-    frame["TIMESTAMP"] = st.session_state["eval_timestamp"]
-    frame["EVAL_TABLE"] = st.session_state["selected_eval_table"]
-    frame["EVAL_TABLE_HASH"] = st.session_state["eval_table_hash"]
-    frame["MODEL_HASH"] = hash(st.session_state["working_yml"])
-
-    # Save results to frame as string
-    frame["ANALYST_RESULT"] = frame["ANALYST_RESULT"].apply(
-        lambda x: x.to_string(index=False) if isinstance(x, pd.DataFrame) else x
-    )
-    frame["GOLD_RESULT"] = frame["GOLD_RESULT"].apply(
-        lambda x: x.to_string(index=False) if isinstance(x, pd.DataFrame) else x
-    )
-
-    frame = frame.reset_index()[list(RESULTS_TABLE_SCHEMA)]
-    write_pandas(
-        conn=get_snowflake_connection(),
-        df=frame,
-        table_name=st.session_state["selected_results_eval_table"],
-        overwrite=False,
-        quote_identifiers=False,
-        auto_create_table=False,
-    )
-    st.write("Evaluation results stored in the database ✅")
-
-
-

 def show() -> None:
     init_session_states()
diff --git a/semantic_model_generator/protos/semantic_model.proto b/semantic_model_generator/protos/semantic_model.proto
index 75fc93f1..2edb0712 100644
--- a/semantic_model_generator/protos/semantic_model.proto
+++ b/semantic_model_generator/protos/semantic_model.proto
@@ -4,7 +4,7 @@
 // python -m grpc_tools.protoc -I=semantic_model_generator/protos/ --python_out=semantic_model_generator/protos/ --pyi_out=semantic_model_generator/protos/ semantic_model_generator/protos/semantic_model.proto
 syntax = "proto3";

-package com.snowflake.cortex.analyst;
+package semantic_model_generator;

 option java_outer_classname = "SemanticModelProto";
 option go_package = "neeva.co/cortexsearch/chat/analyst";
@@ -58,40 +58,40 @@ message RetrievalResult {
 // e.g. `base_column1 + base_column2`.
 message Column {
   // A descriptive name for this column.
-  string name = 1 [ (id_field) = true ];
+  string name = 1 [(id_field) = true];
   // A list of other terms/phrases used to refer to this column.
-  repeated string synonyms = 2 [ (optional) = true ];
+  repeated string synonyms = 2 [(optional) = true];
   // A brief description about this column, including things like what data this
   // column has.
-  string description = 3 [ (optional) = true ];
+  string description = 3 [(optional) = true];
   // The SQL expression for this column. Could simply be a base table column
   // name or an arbitrary SQL expression over one or more columns of the base
   // table.
-  string expr = 4 [ (sql_expression) = true ];
+  string expr = 4 [(sql_expression) = true];
   // The data type of this column.
   string data_type = 5;
   // The kind of this column - dimension or fact, metric.
   ColumnKind kind = 6;
   // If true, assume that this column has unique values.
-  bool unique = 7 [ (optional) = true ];
+  bool unique = 7 [(optional) = true];
   // If no aggregation is specified, then this is the default aggregation
   // applied to this column in context of a grouping.
-  AggregationType default_aggregation = 8 [ (optional) = true ];
+  AggregationType default_aggregation = 8 [(optional) = true];
   // Sample values of this column.
-  repeated string sample_values = 9 [ (optional) = true ];
+  repeated string sample_values = 9 [(optional) = true];
   // Whether to index the values and retrieve them based on the question.
   // If False, all sample values will be used as input to the model.
-  bool index_and_retrieve_values = 10 [ (optional) = true ];
+  bool index_and_retrieve_values = 10 [(optional) = true];
   // Retrieved literals of this column.
-  repeated RetrievalResult retrieved_literals = 11 [ (optional) = true ];
+  repeated RetrievalResult retrieved_literals = 11 [(optional) = true];
   // A Cortex Search Service configured on this column to retrieve literals.
   string cortex_search_service_name = 12
-      [ (optional) = true, deprecated = true ];
-  CortexSearchService cortex_search_service = 13 [ (optional) = true ];
+      [(optional) = true, deprecated = true];
+  CortexSearchService cortex_search_service = 13 [(optional) = true];
   // If true, this column has limited possible values, all of which are in
   // the sample_values field.
-  bool is_enum = 14 [ (optional) = true ];
+  bool is_enum = 14 [(optional) = true];
 }

 // Dimension columns contain categorical values (e.g. state, user_type,
@@ -99,37 +99,37 @@ message Column {
 // context_to_column_format() of snowpilot/semantic_context/protos/schema.py.
 message Dimension {
   // A descriptive name for this dimension.
-  string name = 1 [ (id_field) = true ];
+  string name = 1 [(id_field) = true];
   // A list of other terms/phrases used to refer to this dimension.
-  repeated string synonyms = 2 [ (optional) = true ];
+  repeated string synonyms = 2 [(optional) = true];
   // A brief description about this dimension, including things like
   // what data this dimension has.
-  string description = 3 [ (optional) = true ];
+  string description = 3 [(optional) = true];
   // The SQL expression defining this dimension. Could simply be a physical
   // column name or an arbitrary SQL expression over one or more columns of the
   // physical table.
-  string expr = 4 [ (sql_expression) = true ];
+  string expr = 4 [(sql_expression) = true];
   // The data type of this dimension.
   string data_type = 5;
   // If true, assume that this dimension has unique values.
-  bool unique = 6 [ (optional) = true ];
+  bool unique = 6 [(optional) = true];
   // Sample values of this column.
-  repeated string sample_values = 7 [ (optional) = true ];
+  repeated string sample_values = 7 [(optional) = true];
   // A Cortex Search Service configured on this column to retrieve literals.
-  CortexSearchService cortex_search_service = 8 [ (optional) = true ];
+  CortexSearchService cortex_search_service = 8 [(optional) = true];
   string cortex_search_service_name = 9
-      [ (optional) = true, deprecated = true ];
+      [(optional) = true, deprecated = true];
   // If true, this column has limited possible values, all of which are in
   // the sample_values field.
-  bool is_enum = 10 [ (optional) = true ];
+  bool is_enum = 10 [(optional) = true];
 }

 // Fully qualified Cortex Search Service name.
 message CortexSearchService {
-  string database = 1 [ (optional) = true ];
-  string schema = 2 [ (optional) = true ];
+  string database = 1 [(optional) = true];
+  string schema = 2 [(optional) = true];
   string service = 3;
-  string literal_column = 4 [ (optional) = true ];
+  string literal_column = 4 [(optional) = true];
 }

 // Time dimension columns contain time values (e.g. sale_date, created_at,
@@ -137,22 +137,22 @@ message CortexSearchService {
 // to_column_format() of snowpilot/semantic_context/utils/utils.py.
 message TimeDimension {
   // A descriptive name for this time dimension.
-  string name = 1 [ (id_field) = true ];
+  string name = 1 [(id_field) = true];
   // A list of other terms/phrases used to refer to this time dimension.
-  repeated string synonyms = 2 [ (optional) = true ];
+  repeated string synonyms = 2 [(optional) = true];
   // A brief description about this time dimension, including things like
   // what data it has, the timezone of values, etc.
-  string description = 3 [ (optional) = true ];
+  string description = 3 [(optional) = true];
   // The SQL expression defining this time dimension. Could simply be a physical
   // column name or an arbitrary SQL expression over one or more columns of the
   // physical table.
-  string expr = 4 [ (sql_expression) = true ];
+  string expr = 4 [(sql_expression) = true];
   // The data type of this time dimension.
   string data_type = 5;
   // If true, assume that this time dimension has unique values.
-  bool unique = 6 [ (optional) = true ];
+  bool unique = 6 [(optional) = true];
   // Sample values of this time dimension.
-  repeated string sample_values = 7 [ (optional) = true ];
+  repeated string sample_values = 7 [(optional) = true];
 }

 // Measure columns contain numerical values (e.g. revenue, impressions, salary).
@@ -160,23 +160,23 @@ message TimeDimension {
 // to_column_format() of snowpilot/semantic_context/utils/utils.py.
 message Fact {
   // A descriptive name for this measure.
-  string name = 1 [ (id_field) = true ];
+  string name = 1 [(id_field) = true];
   // A list of other terms/phrases used to refer to this measure.
-  repeated string synonyms = 2 [ (optional) = true ];
+  repeated string synonyms = 2 [(optional) = true];
   // A brief description about this measure, including things like what data
   // it has.
-  string description = 3 [ (optional) = true ];
+  string description = 3 [(optional) = true];
   // The SQL expression defining this measure. Could simply be a physical column
   // name or an arbitrary SQL expression over one or more physical columns of
   // the underlying physical table.
-  string expr = 4 [ (sql_expression) = true ];
+  string expr = 4 [(sql_expression) = true];
   // The data type of this measure.
   string data_type = 5;
   // If no aggregation is specified, then this is the default aggregation
   // applied to this measure in context of a grouping.
-  AggregationType default_aggregation = 6 [ (optional) = true ];
+  AggregationType default_aggregation = 6 [(optional) = true];
   // Sample values of this measure.
-  repeated string sample_values = 7 [ (optional) = true ];
+  repeated string sample_values = 7 [(optional) = true];
 }

 // Filter represents a named SQL expression that's used for filtering.
@@ -184,12 +184,12 @@ message NamedFilter {
   // A descriptive name for this filter.
   string name = 1;
   // A list of other term/phrases used to refer to this column.
-  repeated string synonyms = 2 [ (optional) = true ];
+  repeated string synonyms = 2 [(optional) = true];
   // A brief description about this column, including details of what this
   // filter is typically used for.
-  string description = 3 [ (optional) = true ];
+  string description = 3 [(optional) = true];
   // The SQL expression of this filter.
-  string expr = 4 [ (sql_expression) = true ];
+  string expr = 4 [(sql_expression) = true];
 }

 // FullyQualifiedTable is used to represent three part table names -
@@ -228,12 +228,12 @@ message ForeignKey {
 // table and/or introduce new derived columns.
 message Table {
   // A descriptive name for this table.
-  string name = 1 [ (id_field) = true ];
+  string name = 1 [(id_field) = true];
   // A list of other term/phrases used to refer to this table.
-  repeated string synonyms = 2 [ (optional) = true ];
+  repeated string synonyms = 2 [(optional) = true];
   // A brief description of this table, including details of what kinds of
   // analysis it is typically used for.
-  string description = 3 [ (optional) = true ];
+  string description = 3 [(optional) = true];

   // Fully qualified name of the underlying base table.
   FullyQualifiedTable base_table = 4;
@@ -243,19 +243,19 @@
   // For the external facing yaml specification, we have chosen to go with (2).
   // However, for the time being we'll support both (1) and (2) and continue
   // using (1) as the internal representation.
-  repeated Column columns = 5 [ (optional) = true ];
-  repeated Dimension dimensions = 9 [ (optional) = true ];
-  repeated TimeDimension time_dimensions = 10 [ (optional) = true ];
-  repeated Fact measures = 11 [ (optional) = true, deprecated = true ];
-  repeated Fact facts = 12 [ (optional) = true ];
-  repeated Metric metrics = 13 [ (optional) = true ];
+  repeated Column columns = 5 [(optional) = true];
+  repeated Dimension dimensions = 9 [(optional) = true];
+  repeated TimeDimension time_dimensions = 10 [(optional) = true];
+  repeated Fact measures = 11 [(optional) = true, deprecated = true];
+  repeated Fact facts = 12 [(optional) = true];
+  repeated Metric metrics = 13 [(optional) = true];

   // Primary key of the table, if any.
-  PrimaryKey primary_key = 6 [ (optional) = true ];
+  PrimaryKey primary_key = 6 [(optional) = true];

   // Foreign keys of the table, if any.
-  repeated ForeignKey foreign_keys = 7 [ (optional) = true ];
+  repeated ForeignKey foreign_keys = 7 [(optional) = true];

   // Predefined filters on this table, if any.
-  repeated NamedFilter filters = 8 [ (optional) = true ];
+  repeated NamedFilter filters = 8 [(optional) = true];

   // NEXT_TAG: 14.
 }

@@ -265,30 +265,30 @@ message Table {
 // tables.
 message Metric {
   // A descriptive name of the metric.
-  string name = 1 [ (id_field) = true ];
+  string name = 1 [(id_field) = true];
   // A list of other term/phrases used to refer to this metric.
-  repeated string synonyms = 2 [ (optional) = true ];
+  repeated string synonyms = 2 [(optional) = true];
   // A brief description of this metric, including details of what it computes.
-  string description = 3 [ (optional) = true ];
+  string description = 3 [(optional) = true];
   // The SQL expression to compute this metric.
   // All columns used must be fully qualified with the logical table name.
   // Expression must be an aggregate
-  string expr = 4 [ (sql_expression) = true ];
+  string expr = 4 [(sql_expression) = true];
   // The filter associated with this metric.
   // Do not expose this for now.
-  MetricsFilter filter = 5 [ (optional) = true ];
+  MetricsFilter filter = 5 [(optional) = true];
 }

-message MetricsFilter { string expr = 1 [ (sql_expression) = true ]; }
+message MetricsFilter { string expr = 1 [(sql_expression) = true]; }

 // Type of the join - inner, left outer, etc.
 enum JoinType {
   join_type_unknown = 0;
   inner = 1;
   left_outer = 2;
-  full_outer = 3 [ deprecated = true ];
-  cross = 4 [ deprecated = true ];
-  right_outer = 5 [ deprecated = true ];
+  full_outer = 3 [deprecated = true];
+  cross = 4 [deprecated = true];
+  right_outer = 5 [deprecated = true];
 }

 // Type of the relationship - one-to-one, many-to-one, etc.
@@ -296,8 +296,8 @@ enum RelationshipType {
   relationship_type_unknown = 0;
   one_to_one = 1;
   many_to_one = 2;
-  one_to_many = 3 [ deprecated = true ];
-  many_to_many = 4 [ deprecated = true ];
+  one_to_many = 3 [deprecated = true];
+  many_to_many = 4 [deprecated = true];
 }

 message RelationKey {
@@ -315,9 +315,9 @@ message Relationship {
   // The right hand side table of the join.
   string right_table = 3;
   // The expression used to join left and right tables. Only used internally.
-  string expr = 4 [ (sql_expression) = true, (optional) = true ];
+  string expr = 4 [(sql_expression) = true, (optional) = true];
   // Keys directly represent the join relationship.
-  repeated RelationKey relationship_columns = 7 [ (optional) = true ];
+  repeated RelationKey relationship_columns = 7 [(optional) = true];
   // Type of the join.
   JoinType join_type = 5;
   // Type of the relationship.
@@ -331,15 +331,15 @@ message SemanticModel {
   string name = 1;
   // A brief description of this project, including details of what kind of
   // analysis this project enables.
-  string description = 2 [ (optional) = true ];
+  string description = 2 [(optional) = true];
   // List of tables in this project.
   repeated Table tables = 3;
   // List of relationships in this project.
-  repeated Relationship relationships = 5 [ (optional) = true ];
+  repeated Relationship relationships = 5 [(optional) = true];
   // List of verified queries for this semantic model.
-  repeated VerifiedQuery verified_queries = 6 [ (optional) = true ];
+  repeated VerifiedQuery verified_queries = 6 [(optional) = true];
   // Custom instructions that will be applied to the final SQL generation.
-  string custom_instructions = 7 [ (optional) = true ];
+  string custom_instructions = 7 [(optional) = true];
 }

 // VerifiedQuery represents a (question, sql) pair that has been manually
@@ -348,19 +348,19 @@ message VerifiedQuery {
   // A name for this verified query. Mainly used for display purposes.
   string name = 1;
   // The name of the semantic model on which this verified query is based off.
-  string semantic_model_name = 2 [ (optional) = true ];
+  string semantic_model_name = 2 [(optional) = true];
   // The question being answered.
   string question = 3;
   // The correct SQL query for answering the question.
-  string sql = 4 [ (sql_expression) = true ];
+  string sql = 4 [(sql_expression) = true];
   // Timestamp at which the query was last verified - measured in seconds since
   // epoch, in UTC.
-  int64 verified_at = 5 [ (optional) = true ];
+  int64 verified_at = 5 [(optional) = true];
   // Name of the person who verified this query.
-  string verified_by = 6 [ (optional) = true ];
+  string verified_by = 6 [(optional) = true];
   // Whether to always include this question in the suggested questions
   // module
-  bool use_as_onboarding_question = 7 [ (optional) = true ];
+  bool use_as_onboarding_question = 7 [(optional) = true];
 }

 // VerifiedQueryRepository is simply a collection of verified queries.
diff --git a/semantic_model_generator/snowflake_utils/snowflake_connector.py b/semantic_model_generator/snowflake_utils/snowflake_connector.py
index 70e0a3c1..08c7d684 100644
--- a/semantic_model_generator/snowflake_utils/snowflake_connector.py
+++ b/semantic_model_generator/snowflake_utils/snowflake_connector.py
@@ -519,24 +519,23 @@ def create_table_in_schema(
 def get_valid_schemas_tables_columns_df(
     conn: SnowflakeConnection, table_fqn: str
 ) -> pd.DataFrame:
-    if table_names and not table_schema:
-        logger.warning(
-            "Provided table_name without table_schema, cannot filter to fetch the specific table"
-        )
-    where_clause = ""
-    if table_schema:
-        where_clause += f" where t.table_schema ilike '{table_schema}' "
-    if table_names:
-        table_names_str = ", ".join([f"'{t.lower()}'" for t in table_names])
-        where_clause += f"AND LOWER(t.table_name) in ({table_names_str}) "
-    query = dedent(
-        f"""
-        select t.{_TABLE_SCHEMA_COL}, t.{_TABLE_NAME_COL}, c.{_COLUMN_NAME_COL}, c.{_DATATYPE_COL}, c.{_COMMENT_COL} as {_COLUMN_COMMENT_ALIAS}
-        from {db_name}.information_schema.tables as t
-        join {db_name}.information_schema.columns as c on t.table_schema = c.table_schema and t.table_name = c.table_name{where_clause}
-        order by 1, 2, c.ordinal_position
-        """
-    )
+    database_name, schema_name, table_name = table_fqn.split(".")
+
+    query = f"""
+        select
+            c.{_COLUMN_NAME_COL},
+            c.{_DATATYPE_COL},
+            c.{_COMMENT_COL} as {_COLUMN_COMMENT_ALIAS},
+            t.{_COMMENT_COL} as {_TABLE_COMMENT_COL}
+        from {database_name}.information_schema.tables as t
+        join {database_name}.information_schema.columns as c
+            on true
+            and t.table_schema = c.table_schema
+            and t.table_name = c.table_name
+            and t.table_name ilike '{table_name}'
+        where t.table_schema ilike '{schema_name}'
+        order by c.ordinal_position
+    """
     cursor_execute = conn.cursor().execute(query)
     columns_df = cursor_execute.fetch_pandas_all()  # type: ignore[union-attr]
     return columns_df
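
Reviewer note on the last hunk: get_valid_schemas_tables_columns_df now takes a single fully qualified table name and unpacks it with table_fqn.split("."), so callers must pass exactly three dot-separated parts. A minimal calling sketch, assuming an already-open SnowflakeConnection; the conn setup and the MY_DB.MY_SCHEMA.MY_TABLE name below are illustrative, not part of this patch:

    from semantic_model_generator.snowflake_utils.snowflake_connector import (
        get_valid_schemas_tables_columns_df,
    )

    # conn: an open SnowflakeConnection, e.g. obtained via the admin app's
    # get_snowflake_connection() helper.
    columns_df = get_valid_schemas_tables_columns_df(
        conn=conn,
        table_fqn="MY_DB.MY_SCHEMA.MY_TABLE",  # must be database.schema.table
    )
    # One row per column of the target table, ordered by ordinal_position,
    # carrying the column name, data type, and column/table comments.
    print(columns_df)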