Skip to content

Commit

Permalink
Fix: Replace table prefix as well (#44)
Browse files Browse the repository at this point in the history
* Fix: Replace table prefix as well

* Update src/sql_mock/helpers.py

Co-authored-by: fgarzadeleon <[email protected]>

* Update CHANGELOG.md

---------

Co-authored-by: fgarzadeleon <[email protected]>
  • Loading branch information
Somtom and fgarzadeleon authored Feb 14, 2024
1 parent 5f56851 commit de859a5
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

* Add default target path for dbt
* Improve replacement of tables (also taking into account missing alias)

## [0.6.0]

Expand Down
47 changes: 44 additions & 3 deletions src/sql_mock/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import TYPE_CHECKING, List

import sqlglot
from sqlglot.expressions import replace_tables, to_table
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.expressions import replace_tables
from sqlglot.optimizer.scope import build_scope

from sql_mock.exceptions import ValidationError
Expand All @@ -16,6 +16,7 @@
def get_keys_from_list_of_dicts(data: list[dict]) -> set[str]:
    """
    Collect the distinct keys used across a list of dictionaries.

    Args:
        data (list[dict]): Dictionaries whose keys should be collected

    Returns:
        set[str]: Every key that appears in at least one dictionary.
            Empty when ``data`` is empty or only holds empty dictionaries.
    """
    # Iterating a dict yields its keys, so no explicit `.keys()` call is needed
    return {key for dictionary in data for key in dictionary}


def remove_cte_from_query(query_ast: sqlglot.Expression, cte_name: str) -> sqlglot.Expression:
"""
Remove a CTE from a query
Expand All @@ -29,7 +30,40 @@ def remove_cte_from_query(query_ast: sqlglot.Expression, cte_name: str) -> sqlgl
cte.pop()
return query_ast

def replace_original_table_references(query_ast: sqlglot.Expression, table_ref: str, sql_mock_cte_name: str, dialect: str) -> sqlglot.Expression:

def _replace_table_ref_in_columns(
    query_ast: sqlglot.Expression, table_ref: str, new_ref: str, dialect: str
) -> sqlglot.Expression:
    """
    Replace the original table reference used as a column qualifier with a new ref

    Columns that are written with a table qualifier (e.g. ``data.some_table.col1``)
    get that qualifier replaced by ``new_ref``. Unqualified columns are left untouched.

    Args:
        query_ast (sqlglot.Expression): Original SQL query - parsed by sqlglot. It is modified in place.
        table_ref (str): Table ref to be replaced
        new_ref (str): Name of the new table ref
        dialect (str): The SQL dialect to use for parsing the table references

    Returns:
        sqlglot.Expression: The passed ``query_ast`` with the matching column qualifiers replaced
    """
    ref_table_name = to_table(table_ref, dialect=dialect).name

    root = build_scope(query_ast)
    if root is None:
        # No scope could be built for this expression, so there are no
        # scoped columns to rewrite
        return query_ast

    # Collect the columns up front so the scope tree is fully traversed before any node is changed
    cols_in_query = [col for scope in root.traverse() for col in scope.columns]
    for col in cols_in_query:
        if not col.table:
            continue
        # For column replacement we simplify the comparison to the table name
        # which is why we cast the col.table string to a table object
        col_table = to_table(col.table, dialect=dialect)
        if col_table.name != ref_table_name:
            continue
        col.set("table", new_ref)
        # Make sure to remove the remaining qualifier parts (db and catalog) from the
        # column reference to fully exchange it with the provided table ref.
        # A column has no "schema" part - the prefixes in front of the table are "db" and "catalog".
        col.set("db", None)
        col.set("catalog", None)
    return query_ast


def replace_original_table_references(
query_ast: sqlglot.Expression, table_ref: str, sql_mock_cte_name: str, dialect: str
) -> sqlglot.Expression:
"""
Replace original table reference with sql mock cte name to point them to the mocked data
Expand All @@ -39,6 +73,9 @@ def replace_original_table_references(query_ast: sqlglot.Expression, table_ref:
sql_mock_cte_name (str): Name of the CTE that will contain the mocked data
dialect (str): The SQL dialect to use for parsing the query
"""
query_ast = _replace_table_ref_in_columns(
query_ast=query_ast, table_ref=table_ref, new_ref=sql_mock_cte_name, dialect=dialect
)
return replace_tables(expression=query_ast, mapping={table_ref: sql_mock_cte_name}, dialect=dialect)


Expand Down Expand Up @@ -145,7 +182,11 @@ def validate_all_input_mocks_for_query_provided(query: str, input_mocks: List["B
input_mocks (List[BaseTableMock]): The input mocks that are provided
dialect (str): The SQL dialect to use for parsing the query
"""
provided_table_refs = [table_mock._sql_mock_meta.table_ref for table_mock in input_mocks if hasattr(table_mock._sql_mock_meta, "table_ref")]
provided_table_refs = [
table_mock._sql_mock_meta.table_ref
for table_mock in input_mocks
if hasattr(table_mock._sql_mock_meta, "table_ref")
]
ast = sqlglot.parse_one(query, dialect=dialect)

# In case the table_ref is a CTE, we need to remove it from the query
Expand Down
27 changes: 20 additions & 7 deletions tests/sql_mock/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ def test_replace_original_table_references_when_reference_does_not_exist(self):
dialect="bigquery",
).sql(pretty=True)

def test_replace_original_table_reference_when_used_in_col_ref(self):
    """...then the qualifier of the column reference should point to the mock CTE as well"""
    # The mocked table is joined without an alias, so its columns are
    # qualified with the table reference itself
    query = f"""
    SELECT {MockTestTable._sql_mock_meta.table_ref}.col1
    FROM data.some_table as b
    JOIN {MockTestTable._sql_mock_meta.table_ref} ON {MockTestTable._sql_mock_meta.table_ref}.col1 = b.col1
    """
    expected = f"SELECT\n {MockTestTable._sql_mock_meta.cte_name}.col1\nFROM data.some_table AS b\nJOIN {MockTestTable._sql_mock_meta.cte_name} /* data.mock_test_table */\n ON {MockTestTable._sql_mock_meta.cte_name}.col1 = b.col1"

    replaced_ast = replace_original_table_references(
        query_ast=sqlglot.parse_one(query),
        table_ref=MockTestTable._sql_mock_meta.table_ref,
        sql_mock_cte_name=MockTestTable._sql_mock_meta.cte_name,
        dialect="bigquery",
    )
    actual = replaced_ast.sql(pretty=True)

    assert expected == actual


class TestSelectFromCTE:
def test_select_from_cte_when_cte_exists(self):
Expand Down Expand Up @@ -213,13 +228,12 @@ def test_input_mocks_missing_for_tables_within_mocked_cte(self):
SELECT a, b, * FROM cte_2
"""

@table_meta(table_ref="cte_1")
class Cte1Mock(BaseTableMock):
pass

validate_all_input_mocks_for_query_provided(
query=query, input_mocks=[Cte1Mock()], dialect="bigquery"
)
validate_all_input_mocks_for_query_provided(query=query, input_mocks=[Cte1Mock()], dialect="bigquery")

def test_cte_superfluous_after_mocking(self):
"""...then the validation should pass since the CTE will be removed anyways and does not need to be mocked"""
Expand All @@ -236,13 +250,12 @@ def test_cte_superfluous_after_mocking(self):
SELECT a, b, * FROM cte_2
"""
@table_meta(table_ref="cte_2") # This will make cte_1 superfluous

@table_meta(table_ref="cte_2") # This will make cte_1 superfluous
class Cte1Mock(BaseTableMock):
pass

validate_all_input_mocks_for_query_provided(
query=query, input_mocks=[Cte1Mock()], dialect="bigquery"
)
validate_all_input_mocks_for_query_provided(query=query, input_mocks=[Cte1Mock()], dialect="bigquery")


class TestGetSourceTables:
Expand Down

0 comments on commit de859a5

Please sign in to comment.