Skip to content

Commit

Permalink
Make table ref parsing more consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Schmidt committed Dec 29, 2023
1 parent 20059d7 commit 1827259
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 30 deletions.
7 changes: 6 additions & 1 deletion src/sql_mock/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def select_from_cte(query: str, cte_name: str, sql_dialect: str):
cte_name (str): Name of the CTE to select from
sql_dialect (str): The sql dialect to use for generating the query
"""
ast = sqlglot.parse_one(query)
ast = sqlglot.parse_one(query, dialect=sql_dialect)

# Check whether the cte exists, if not raise an error
cte_exists = any(cte.alias == cte_name for cte in ast.find_all(sqlglot.exp.CTE))
Expand All @@ -61,6 +61,11 @@ def select_from_cte(query: str, cte_name: str, sql_dialect: str):
return adjusted_query


def parse_table_refs(table_ref, dialect):
"""Method to standardize how we parse table refs to avoid differences"""
return str(sqlglot.parse_one(table_ref, dialect=dialect))


def _strip_alias_transformer(node):
node.set("alias", None)
return node
Expand Down
54 changes: 27 additions & 27 deletions src/sql_mock/table_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,14 @@
from sql_mock.constants import NO_INPUT
from sql_mock.helpers import (
get_keys_from_list_of_dicts,
parse_table_refs,
replace_original_table_references,
select_from_cte,
validate_all_input_mocks_for_query_provided,
validate_input_mocks,
)


class MockTableMeta(BaseModel):
"""
Class to store static metadata of BaseMockTable instances which is used during processing.
We use this class to avoid collision with field names of the table we want to mock.
Attributes:
table_ref (string) : String that represents the table reference to the original table.
query (string): Srting of the SQL query (can be in Jinja format).
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

default_inputs: List[SkipValidation["BaseMockTable"]] = None
table_ref: str = None
query: str = None

@property
def cte_name(self):
if getattr(self, "table_ref", None):
return self.table_ref.replace(".", "__")


def table_meta(
table_ref: str = "", query_path: str = None, query: str = None, default_inputs: ["BaseMockTable"] = None
):
Expand All @@ -51,13 +30,13 @@ def table_meta(
"""

def decorator(cls):
mock_meta_kwargs = {"table_ref": table_ref}
mock_meta_kwargs = {"table_ref": parse_table_refs(table_ref, dialect=cls._sql_dialect)}

if query_path:
with open(query_path) as f:
mock_meta_kwargs['query'] = f.read()
mock_meta_kwargs["query"] = f.read()
elif query:
mock_meta_kwargs['query'] = query
mock_meta_kwargs["query"] = query

if default_inputs:
validate_input_mocks(default_inputs)
Expand All @@ -84,7 +63,6 @@ class SQLMockData(BaseModel):
last_query: str = None



class BaseMockTable:
"""
Represents a base class for creating mock database tables for testing.
Expand All @@ -98,7 +76,7 @@ class BaseMockTable:
"""

_sql_mock_data: SQLMockData = None
_sql_mock_meta: MockTableMeta = None
_sql_mock_meta: "MockTableMeta" = None
_sql_dialect: str = None

def __init__(self, data: list[dict] = None, sql_mock_data: SQLMockData = None) -> None:
Expand Down Expand Up @@ -341,3 +319,25 @@ def assert_equal(
ignore_order=ignore_order,
print_query_on_fail=print_query_on_fail,
)


class MockTableMeta(BaseModel):
"""
Class to store static metadata of BaseMockTable instances which is used during processing.
We use this class to avoid collision with field names of the table we want to mock.
Attributes:
table_ref (string) : String that represents the table reference to the original table.
query (string): Srting of the SQL query (can be in Jinja format).
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

default_inputs: List[SkipValidation["BaseMockTable"]] = None
table_ref: str = None
query: str = None

@property
def cte_name(self):
if getattr(self, "table_ref", None):
return self.table_ref.replace('"', "").replace(".", "__")
4 changes: 2 additions & 2 deletions tests/sql_mock/test_table_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_query_path_provided(mocker):
# Configure the mock to return the file content
mock_open.return_value.__enter__.return_value.read.return_value = query

@table_meta(table_ref="", query_path=query_path)
@table_meta(table_ref="some.table", query_path=query_path)
class TestMock(BaseMockTable):
pass

Expand All @@ -20,7 +20,7 @@ class TestMock(BaseMockTable):
def test_no_query_path_provided():
"""...then there should not be any query string stored on the cls._sql_mock_data"""

@table_meta(table_ref="")
@table_meta(table_ref="some.table")
class TestMock(BaseMockTable):
pass

Expand Down

0 comments on commit 1827259

Please sign in to comment.