Skip to content

Commit

Permalink
Move more attributes to Metadata class
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Schmidt committed Dec 28, 2023
1 parent 206eb01 commit 20059d7
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 37 deletions.
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@ default_install_hook_types:
- pre-commit
- pre-push
repos:
- repo: https://github.com/dhruvmanila/remove-print-statements
rev: 'v0.4.0' # Replace with latest tag on GitHub
hooks:
- id: remove-print-statements
exclude: "^(.gitlab|.scripts|sisyphus/test_tasks|packages/achtung)"
args: ['--verbose'] # Show all the print statements to be removed
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
Expand Down
6 changes: 3 additions & 3 deletions src/sql_mock/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def replace_original_table_references(query: str, mock_tables: list["BaseMockTab
dialect (str): The SQL dialect to use for parsing the query
"""
ast = sqlglot.parse_one(query, dialect=dialect)
mapping = {mock_table._sql_mock_data.table_ref: mock_table._sql_mock_data.cte_name for mock_table in mock_tables}
mapping = {mock_table._sql_mock_meta.table_ref: mock_table._sql_mock_meta.cte_name for mock_table in mock_tables}
res = replace_tables(expression=ast, mapping=mapping, dialect=dialect).sql(pretty=True, dialect=dialect)
return res

Expand Down Expand Up @@ -105,7 +105,7 @@ def _validate_input_mocks_have_table_ref(input_mocks: List["BaseMockTable"]) ->
missing_table_refs = [
type(mock_table).__name__
for mock_table in input_mocks
if not getattr(mock_table._sql_mock_data, "table_ref", False)
if not getattr(mock_table._sql_mock_meta, "table_ref", False)
]

if missing_table_refs:
Expand All @@ -122,7 +122,7 @@ def validate_input_mocks(input_mocks: List["BaseMockTable"]):
def validate_all_input_mocks_for_query_provided(query: str, input_mocks: List["BaseMockTable"], dialect: str) -> None:
missing_source_table_mocks = get_source_tables(query=query, dialect=dialect)
for mock_table in input_mocks:
table_ref = getattr(mock_table._sql_mock_data, "table_ref", None)
table_ref = getattr(mock_table._sql_mock_meta, "table_ref", None)
# If the table exists as mock, we can remove it from missing source tables
try:
missing_source_table_mocks.remove(table_ref)
Expand Down
54 changes: 36 additions & 18 deletions src/sql_mock/table_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,28 @@
)


class MockTableMeta(BaseModel):
"""
Class to store static metadata of BaseMockTable instances which is used during processing.
We use this class to avoid collision with field names of the table we want to mock.
Attributes:
table_ref (string) : String that represents the table reference to the original table.
query (string): Srting of the SQL query (can be in Jinja format).
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

default_inputs: List[SkipValidation["BaseMockTable"]] = None
table_ref: str = None
query: str = None

@property
def cte_name(self):
if getattr(self, "table_ref", None):
return self.table_ref.replace(".", "__")


def table_meta(
table_ref: str = "", query_path: str = None, query: str = None, default_inputs: ["BaseMockTable"] = None
):
Expand All @@ -29,17 +51,19 @@ def table_meta(
"""

def decorator(cls):
parsed_query = ""
mock_meta_kwargs = {"table_ref": table_ref}

if query_path:
with open(query_path) as f:
parsed_query = f.read()
mock_meta_kwargs['query'] = f.read()
elif query:
parsed_query = query
mock_meta_kwargs['query'] = query

if default_inputs:
validate_input_mocks(default_inputs)
mock_meta_kwargs["default_inputs"] = default_inputs

cls._sql_mock_data = SQLMockData(table_ref=table_ref, query=parsed_query, default_inputs=default_inputs or [])
cls._sql_mock_meta = MockTableMeta(**mock_meta_kwargs)
return cls

return decorator
Expand All @@ -53,19 +77,12 @@ class SQLMockData(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)

default_inputs: List[SkipValidation["BaseMockTable"]] = None
columns: dict[str, Type[ColumnMock]] = None
data: list[dict] = None
input_data: list[dict] = None
table_ref: str = None
query: str = None
rendered_query: str = None
last_query: str = None

@property
def cte_name(self):
if getattr(self, "table_ref", None):
return self.table_ref.replace(".", "__")


class BaseMockTable:
Expand All @@ -81,6 +98,7 @@ class BaseMockTable:
"""

_sql_mock_data: SQLMockData = None
_sql_mock_meta: MockTableMeta = None
_sql_dialect: str = None

def __init__(self, data: list[dict] = None, sql_mock_data: SQLMockData = None) -> None:
Expand Down Expand Up @@ -125,19 +143,19 @@ def from_mocks(
Arguments:
input_data: List of MockTable instances that hold static data that should be used as inputs.
query_template_kwargs: Dictionary of Jinja template key-value pairs that should be used to render the query.
query: String of the SQL query that is used to generate the model. Can be a Jinja template. If provided, it overwrites the query on cls._sql_mock_data.query.
query: String of the SQL query that is used to generate the model. Can be a Jinja template. If provided, it overwrites the query on cls._sql_mock_meta.query.
"""
instance = cls(data=[])
query_template = Template(query or cls._sql_mock_data.query)
query_template = Template(query or cls._sql_mock_meta.query)
query = query_template.render(query_template_kwargs or {})
instance._sql_mock_data.rendered_query = query

# Update defaults with provided data. We use the table ref dictionaries to avoid duplicated inputs.
if getattr(cls._sql_mock_data, "default_inputs", None):
if getattr(cls._sql_mock_meta, "default_inputs", None):
default_inputs = {
mock_table._sql_mock_data.table_ref: mock_table for mock_table in cls._sql_mock_data.default_inputs
mock_table._sql_mock_meta.table_ref: mock_table for mock_table in cls._sql_mock_meta.default_inputs
}
input_dict = {mock_table._sql_mock_data.table_ref: mock_table for mock_table in input_data}
input_dict = {mock_table._sql_mock_meta.table_ref: mock_table for mock_table in input_data}
input_data = list({**default_inputs, **input_dict}.values())

validate_input_mocks(input_data)
Expand Down Expand Up @@ -234,7 +252,7 @@ def as_sql_input(self):

# Indent whole CTE content for better query readability
snippet = indent(f"SELECT {snippet}", "\t")
return f"{self._sql_mock_data.cte_name} AS (\n{snippet}\n)"
return f"{self._sql_mock_meta.cte_name} AS (\n{snippet}\n)"

def _assert_equal(
self,
Expand Down Expand Up @@ -265,7 +283,7 @@ def _assert_equal(
assert expected == data
except Exception as e:
if print_query_on_fail:
pass
print(self._sql_mock_data.last_query)
raise e

def assert_cte_equal(
Expand Down
2 changes: 1 addition & 1 deletion tests/sql_mock/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class MockTestTable(BaseMockTable):
class TestReplaceOriginalTableReference:
def test_replace_original_table_references_when_reference_exists(self):
"""...then the original table reference should be replaced with the mocked table reference"""
query = f"SELECT * FROM {MockTestTable._sql_mock_data.table_ref}"
query = f"SELECT * FROM {MockTestTable._sql_mock_meta.table_ref}"
mock_tables = [MockTestTable()]
# Note that sqlglot will add a comment with the original table name at the end
expected = "SELECT\n *\nFROM data__mock_test_table /* data.mock_test_table */"
Expand Down
6 changes: 3 additions & 3 deletions tests/sql_mock/test_table_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_query_path_provided(mocker):
class TestMock(BaseMockTable):
pass

assert TestMock._sql_mock_data.query == query
assert TestMock._sql_mock_meta.query == query
mock_open.assert_called_with(query_path)


Expand All @@ -24,7 +24,7 @@ def test_no_query_path_provided():
class TestMock(BaseMockTable):
pass

assert TestMock._sql_mock_data.query == ""
assert TestMock._sql_mock_meta.query is None


def test_table_ref_provided():
Expand All @@ -35,4 +35,4 @@ def test_table_ref_provided():
class TestMock(BaseMockTable):
pass

assert TestMock._sql_mock_data.table_ref == table_ref
assert TestMock._sql_mock_meta.table_ref == table_ref
4 changes: 2 additions & 2 deletions tests/sql_mock/test_table_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_from_mocks(self, base_mock_table_instance, mocker):

def test_from_mocks_with_defaults(self, base_mock_table_instance, mocker):
query = "SELECT * FROM some_table"
input_data = [*MockTestTableWithDefaults._sql_mock_data.default_inputs, base_mock_table_instance]
input_data = [*MockTestTableWithDefaults._sql_mock_meta.default_inputs, base_mock_table_instance]
query_template_kwargs = {}
mocked_validate_input_mocks_for_query = mocker.patch(
"sql_mock.table_mocks.validate_all_input_mocks_for_query_provided"
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_as_sql_input():
]
sql_input = mock_table_instance.as_sql_input()
expected = (
f"{mock_table_instance._sql_mock_data.table_ref} AS (\n"
f"{mock_table_instance._sql_mock_meta.table_ref} AS (\n"
"\tSELECT cast('1' AS Integer) AS col1, cast('value1' AS String) AS col2\n"
"\tUNION ALL\n"
"\tSELECT cast('2' AS Integer) AS col1, cast('value2' AS String) AS col2\n"
Expand Down
8 changes: 4 additions & 4 deletions tests/test_table_mocks/test_generate_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_generate_query_no_cte_provided(mocker):
# Arrange
mock_table_instance = MockTestTable.from_dicts([])
mock_table_instance._sql_mock_data.input_data = [mock_table_instance]
original_query = f"SELECT * FROM {mock_table_instance._sql_mock_data.table_ref}"
original_query = f"SELECT * FROM {mock_table_instance._sql_mock_meta.table_ref}"
cte_to_select = "some_cte"
cte_adjusted_query = f"SELECT * FROM {cte_to_select}"
dummy_return_query = "SELECT foo FROM bar"
Expand All @@ -42,7 +42,7 @@ def test_generate_query_no_cte_provided(mocker):

expected_query_template_result = dedent(
f"""
WITH {mock_table_instance._sql_mock_data.cte_name} AS (
WITH {mock_table_instance._sql_mock_meta.cte_name} AS (
\tSELECT cast('1' AS Integer) AS col1, cast('hey' AS String) AS col2 FROM (SELECT 1) WHERE FALSE
),
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_generate_query_cte_provided(mocker):
# Arrange
mock_table_instance = MockTestTable.from_dicts([])
mock_table_instance._sql_mock_data.input_data = [mock_table_instance]
original_query = f"SELECT * FROM {mock_table_instance._sql_mock_data.table_ref}"
original_query = f"SELECT * FROM {mock_table_instance._sql_mock_meta.table_ref}"
cte_to_select = "some_cte"
cte_adjusted_query = f"SELECT * FROM {cte_to_select}"
dummy_return_query = "SELECT foo FROM bar"
Expand All @@ -87,7 +87,7 @@ def test_generate_query_cte_provided(mocker):

expected_query_template_result = dedent(
f"""
WITH {mock_table_instance._sql_mock_data.cte_name} AS (
WITH {mock_table_instance._sql_mock_meta.cte_name} AS (
\tSELECT cast('1' AS Integer) AS col1, cast('hey' AS String) AS col2 FROM (SELECT 1) WHERE FALSE
),
Expand Down

0 comments on commit 20059d7

Please sign in to comment.