diff --git a/CHANGELOG.md b/CHANGELOG.md index 0624efd..7f2bc49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed * Add default target path for dbt +* Improve replacement of tables (also taking into account missing alias) ## [0.6.0] diff --git a/src/sql_mock/helpers.py b/src/sql_mock/helpers.py index 2549173..6b6443e 100644 --- a/src/sql_mock/helpers.py +++ b/src/sql_mock/helpers.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, List import sqlglot +from sqlglot.expressions import replace_tables, to_table from sqlglot.optimizer.eliminate_ctes import eliminate_ctes -from sqlglot.expressions import replace_tables from sqlglot.optimizer.scope import build_scope from sql_mock.exceptions import ValidationError @@ -16,6 +16,7 @@ def get_keys_from_list_of_dicts(data: list[dict]) -> set[str]: return set(key for dictionary in data for key in dictionary.keys()) + def remove_cte_from_query(query_ast: sqlglot.Expression, cte_name: str) -> sqlglot.Expression: """ Remove a CTE from a query @@ -29,7 +30,40 @@ def remove_cte_from_query(query_ast: sqlglot.Expression, cte_name: str) -> sqlgl cte.pop() return query_ast -def replace_original_table_references(query_ast: sqlglot.Expression, table_ref: str, sql_mock_cte_name: str, dialect: str) -> sqlglot.Expression: + +def _replace_table_ref_in_columns( + query_ast: sqlglot.Expression, table_ref: str, new_ref: str, dialect: str +) -> sqlglot.Expression: + """ + Replace original table reference with a new ref + + Args: + query_ast (str): Original SQL query - parsed by sqlglot + table_ref (str): Table ref to be replaced + new_ref (str): Name of the new table ref + """ + ref_table = to_table(table_ref, dialect=dialect) + + root = build_scope(query_ast) + cols_in_query = [col for scope in root.traverse() for col in scope.columns] + for col in cols_in_query: + if not col.table: + continue + # For column replacement we simplify the comparison to the table name + # which is why we cast the col.table string to a table object + col_table = to_table(col.table, dialect=dialect) + if col_table.name == ref_table.name: + col.set("table", new_ref) + # Make sure to remove the schema and db from the col table reference + # to fully exchange it with the provided table ref + col.set("schema", None) + col.set("db", None) + return query_ast + + +def replace_original_table_references( + query_ast: sqlglot.Expression, table_ref: str, sql_mock_cte_name: str, dialect: str +) -> sqlglot.Expression: """ Replace original table reference with sql mock cte name to point them to the mocked data @@ -39,6 +73,9 @@ def replace_original_table_references(query_ast: sqlglot.Expression, table_ref: sql_mock_cte_name (str): Name of the CTE that will contain the mocked data dialect (str): The SQL dialect to use for parsing the query """ + query_ast = _replace_table_ref_in_columns( + query_ast=query_ast, table_ref=table_ref, new_ref=sql_mock_cte_name, dialect=dialect + ) return replace_tables(expression=query_ast, mapping={table_ref: sql_mock_cte_name}, dialect=dialect) @@ -145,7 +182,11 @@ def validate_all_input_mocks_for_query_provided(query: str, input_mocks: List["B input_mocks (List[BaseTableMock]): The input mocks that are provided dialect (str): The SQL dialect to use for parsing the query """ - provided_table_refs = [table_mock._sql_mock_meta.table_ref for table_mock in input_mocks if hasattr(table_mock._sql_mock_meta, "table_ref")] + provided_table_refs = [ + table_mock._sql_mock_meta.table_ref + for table_mock in input_mocks + if hasattr(table_mock._sql_mock_meta, "table_ref") + ] ast = sqlglot.parse_one(query, dialect=dialect) # In case the table_ref is a CTE, we need to remove it from the query diff --git a/tests/sql_mock/test_helpers.py b/tests/sql_mock/test_helpers.py index e5b2ff5..e59700d 100644 --- a/tests/sql_mock/test_helpers.py +++ b/tests/sql_mock/test_helpers.py @@ -69,6 +69,21 @@ def test_replace_original_table_references_when_reference_does_not_exist(self): dialect="bigquery", ).sql(pretty=True) + def test_replace_original_table_reference_when_used_in_col_ref(self): + """...then the column reference should also be replaced""" + query = f""" + SELECT {MockTestTable._sql_mock_meta.table_ref}.col1 + FROM data.some_table as b + JOIN {MockTestTable._sql_mock_meta.table_ref} ON {MockTestTable._sql_mock_meta.table_ref}.col1 = b.col1 + """ + expected = f"SELECT\n {MockTestTable._sql_mock_meta.cte_name}.col1\nFROM data.some_table AS b\nJOIN {MockTestTable._sql_mock_meta.cte_name} /* data.mock_test_table */\n ON {MockTestTable._sql_mock_meta.cte_name}.col1 = b.col1" + assert expected == replace_original_table_references( + query_ast=sqlglot.parse_one(query), + table_ref=MockTestTable._sql_mock_meta.table_ref, + sql_mock_cte_name=MockTestTable._sql_mock_meta.cte_name, + dialect="bigquery", + ).sql(pretty=True) + class TestSelectFromCTE: def test_select_from_cte_when_cte_exists(self): @@ -213,13 +228,12 @@ def test_input_mocks_missing_for_tables_within_mocked_cte(self): SELECT a, b, * FROM cte_2 """ + @table_meta(table_ref="cte_1") class Cte1Mock(BaseTableMock): pass - validate_all_input_mocks_for_query_provided( - query=query, input_mocks=[Cte1Mock()], dialect="bigquery" - ) + validate_all_input_mocks_for_query_provided(query=query, input_mocks=[Cte1Mock()], dialect="bigquery") def test_cte_superfluous_after_mocking(self): """...then the validation should pass since the CTE will be removed anyways and does not need to be mocked""" @@ -236,13 +250,12 @@ def test_cte_superfluous_after_mocking(self): SELECT a, b, * FROM cte_2 """ - @table_meta(table_ref="cte_2") # This will make cte_1 superfluous + + @table_meta(table_ref="cte_2") # This will make cte_1 superfluous class Cte1Mock(BaseTableMock): pass - validate_all_input_mocks_for_query_provided( - query=query, input_mocks=[Cte1Mock()], dialect="bigquery" - ) + validate_all_input_mocks_for_query_provided(query=query, input_mocks=[Cte1Mock()], dialect="bigquery") class TestGetSourceTables: