Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Replace table prefix as well #44

Merged
merged 3 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

* Add default target path for dbt
* Improve replacement of tables (also taking into account missing alias)

## [0.6.0]

Expand Down
47 changes: 44 additions & 3 deletions src/sql_mock/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import TYPE_CHECKING, List

import sqlglot
from sqlglot.expressions import replace_tables, to_table
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.expressions import replace_tables
from sqlglot.optimizer.scope import build_scope

from sql_mock.exceptions import ValidationError
Expand All @@ -16,6 +16,7 @@
def get_keys_from_list_of_dicts(data: list[dict]) -> set[str]:
    """Return the union of all keys used by the dictionaries in ``data``.

    Args:
        data (list[dict]): Dictionaries whose keys should be collected.

    Returns:
        set[str]: Every distinct key that appears in at least one of the
            dictionaries. Empty when ``data`` is empty or only holds empty
            dictionaries.
    """
    # Iterating a dict yields its keys, so no explicit ``.keys()`` call is needed.
    return {key for dictionary in data for key in dictionary}


def remove_cte_from_query(query_ast: sqlglot.Expression, cte_name: str) -> sqlglot.Expression:
"""
Remove a CTE from a query
Expand All @@ -29,7 +30,40 @@ def remove_cte_from_query(query_ast: sqlglot.Expression, cte_name: str) -> sqlgl
cte.pop()
return query_ast

def replace_original_table_references(query_ast: sqlglot.Expression, table_ref: str, sql_mock_cte_name: str, dialect: str) -> sqlglot.Expression:

def _replace_table_ref_in_columns(
    query_ast: sqlglot.Expression, table_ref: str, new_ref: str, dialect: str
) -> sqlglot.Expression:
    """
    Replace table-qualified column references to a table with a new reference

    Only columns that carry a table qualifier are touched. The comparison is
    done on the bare table name, so any db / catalog prefix on the column
    qualifier is ignored while matching and dropped when replacing.

    Args:
        query_ast (sqlglot.Expression): Original SQL query - parsed by sqlglot
        table_ref (str): Table ref to be replaced
        new_ref (str): Name of the new table ref
        dialect (str): The SQL dialect used to parse the table references

    Returns:
        sqlglot.Expression: The ``query_ast`` that was passed in; matching
            column nodes are modified in place.
    """
    # The reference name does not change per column, so resolve it once.
    ref_table_name = to_table(table_ref, dialect=dialect).name

    root = build_scope(query_ast)
    if root is None:
        # No scope could be built for this expression, so there are no
        # column references to rewrite.
        return query_ast

    # Collect all columns first so the AST is not mutated while scopes are traversed.
    cols_in_query = [col for scope in root.traverse() for col in scope.columns]
    for col in cols_in_query:
        if not col.table:
            continue
        # For column replacement we simplify the comparison to the table name
        # which is why we cast the col.table string to a table object
        col_table = to_table(col.table, dialect=dialect)
        if col_table.name != ref_table_name:
            continue
        col.set("table", new_ref)
        # Remove the remaining qualifiers from the column so the reference is
        # fully exchanged with the provided table ref. A sqlglot Column stores
        # them under "db" and "catalog" (it has no "schema" arg).
        col.set("db", None)
        col.set("catalog", None)
    return query_ast


def replace_original_table_references(
query_ast: sqlglot.Expression, table_ref: str, sql_mock_cte_name: str, dialect: str
) -> sqlglot.Expression:
"""
Replace original table reference with sql mock cte name to point them to the mocked data

Expand All @@ -39,6 +73,9 @@ def replace_original_table_references(query_ast: sqlglot.Expression, table_ref:
sql_mock_cte_name (str): Name of the CTE that will contain the mocked data
dialect (str): The SQL dialect to use for parsing the query
"""
query_ast = _replace_table_ref_in_columns(
query_ast=query_ast, table_ref=table_ref, new_ref=sql_mock_cte_name, dialect=dialect
)
return replace_tables(expression=query_ast, mapping={table_ref: sql_mock_cte_name}, dialect=dialect)


Expand Down Expand Up @@ -145,7 +182,11 @@ def validate_all_input_mocks_for_query_provided(query: str, input_mocks: List["B
input_mocks (List[BaseTableMock]): The input mocks that are provided
dialect (str): The SQL dialect to use for parsing the query
"""
provided_table_refs = [table_mock._sql_mock_meta.table_ref for table_mock in input_mocks if hasattr(table_mock._sql_mock_meta, "table_ref")]
provided_table_refs = [
table_mock._sql_mock_meta.table_ref
for table_mock in input_mocks
if hasattr(table_mock._sql_mock_meta, "table_ref")
]
ast = sqlglot.parse_one(query, dialect=dialect)

# In case the table_ref is a CTE, we need to remove it from the query
Expand Down
27 changes: 20 additions & 7 deletions tests/sql_mock/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ def test_replace_original_table_references_when_reference_does_not_exist(self):
dialect="bigquery",
).sql(pretty=True)

def test_replace_original_table_reference_when_used_in_col_ref(self):
"""...then the column reference should also be replaced"""
# The mocked table is selected from and joined WITHOUT an alias, so its
# columns are qualified with the full table reference instead of an alias.
query = f"""
SELECT {MockTestTable._sql_mock_meta.table_ref}.col1
FROM data.some_table as b
JOIN {MockTestTable._sql_mock_meta.table_ref} ON {MockTestTable._sql_mock_meta.table_ref}.col1 = b.col1
"""
# Expected pretty-printed SQL: the JOIN target and both qualified column
# references use the CTE name, while the unrelated `data.some_table AS b` is untouched.
# NOTE(review): the `/* data.mock_test_table */` comment is hard-coded here and is
# presumably emitted by sqlglot's replace_tables for the original table - verify.
expected = f"SELECT\n {MockTestTable._sql_mock_meta.cte_name}.col1\nFROM data.some_table AS b\nJOIN {MockTestTable._sql_mock_meta.cte_name} /* data.mock_test_table */\n ON {MockTestTable._sql_mock_meta.cte_name}.col1 = b.col1"
# NOTE(review): the query is parsed without an explicit dialect while the
# replacement itself runs with dialect="bigquery" - confirm this is intended.
assert expected == replace_original_table_references(
query_ast=sqlglot.parse_one(query),
table_ref=MockTestTable._sql_mock_meta.table_ref,
sql_mock_cte_name=MockTestTable._sql_mock_meta.cte_name,
dialect="bigquery",
).sql(pretty=True)


class TestSelectFromCTE:
def test_select_from_cte_when_cte_exists(self):
Expand Down Expand Up @@ -213,13 +228,12 @@ def test_input_mocks_missing_for_tables_within_mocked_cte(self):

SELECT a, b, * FROM cte_2
"""

@table_meta(table_ref="cte_1")
class Cte1Mock(BaseTableMock):
pass

validate_all_input_mocks_for_query_provided(
query=query, input_mocks=[Cte1Mock()], dialect="bigquery"
)
validate_all_input_mocks_for_query_provided(query=query, input_mocks=[Cte1Mock()], dialect="bigquery")

def test_cte_superfluous_after_mocking(self):
"""...then the validation should pass since the CTE will be removed anyways and does not need to be mocked"""
Expand All @@ -236,13 +250,12 @@ def test_cte_superfluous_after_mocking(self):

SELECT a, b, * FROM cte_2
"""
@table_meta(table_ref="cte_2") # This will make cte_1 superfluous

@table_meta(table_ref="cte_2") # This will make cte_1 superfluous
class Cte1Mock(BaseTableMock):
pass

validate_all_input_mocks_for_query_provided(
query=query, input_mocks=[Cte1Mock()], dialect="bigquery"
)
validate_all_input_mocks_for_query_provided(query=query, input_mocks=[Cte1Mock()], dialect="bigquery")


class TestGetSourceTables:
Expand Down
Loading