Commit

tests
sfc-gh-jhilgart committed Apr 15, 2024
1 parent ea795de commit 273e5e8
Showing 1 changed file with 285 additions and 0 deletions.
semantic_model_generator/tests/generate_model_test.py (285 additions, 0 deletions)
@@ -0,0 +1,285 @@
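"""Tests for semantic_model_generator.generate_model: snake-case conversion, building a
semantic context from Snowflake schema metadata, and YAML serialization of semantic models."""
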
from unittest.mock import MagicMock, mock_open, patch

import pandas as pd
import pytest
import yaml

from semantic_model_generator.data_processing import proto_utils
from semantic_model_generator.data_processing.data_types import Column, Table
from semantic_model_generator.generate_model import (
_to_snake_case,
generate_base_semantic_context_from_snowflake,
raw_schema_to_semantic_context,
)
from semantic_model_generator.protos import semantic_model_pb2
from semantic_model_generator.snowflake_utils.snowflake_connector import (
SnowflakeConnector,
)


def test_to_snake_case():
text = "Hello World-How are_you"

assert "hello_world_how_are_you" == _to_snake_case(text)


@pytest.fixture
def mock_snowflake_connection():
"""Fixture to mock the snowflake_connection function."""
with patch(
"semantic_model_generator.snowflake_utils.snowflake_connector.snowflake_connection"
) as mock:
mock.return_value = MagicMock()
yield mock


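# Expected Table representations returned by the mocked get_table_representation calls below.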
_CONVERTED_TABLE_ALIAS = Table(
id_=0,
name="ALIAS",
columns=[
Column(
id_=0,
column_name="ZIP_CODE",
column_type="TEXT",
values=None,
comment=None,
),
Column(
id_=1,
column_name="AREA_CODE",
column_type="NUMBER",
values=None,
comment=None,
),
Column(
id_=2,
column_name="BAD_ALIAS",
column_type="TIMESTAMP",
values=None,
comment=None,
),
Column(
id_=3,
column_name="CBSA",
column_type="NUMBER",
values=None,
comment=None,
),
],
comment=None,
)

_CONVERTED_TABLE_ZIP_CODE = Table(
id_=0,
name="PRODUCTS",
columns=[
Column(
id_=0,
column_name="SKU",
column_type="NUMBER",
values=["1", "2", "3"],
comment=None,
),
],
comment=None,
)


@pytest.fixture
def mock_snowflake_connection_env(monkeypatch):
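    """Sets SNOWFLAKE_HOST and patches SnowflakeConnector credential getters with test values."""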
# Mock environment variable
monkeypatch.setenv("SNOWFLAKE_HOST", "test_host")

    # Patch SnowflakeConnector credential getters so no real credentials are required.
with patch.object(
SnowflakeConnector, "_get_user", return_value="test_user"
), patch.object(
SnowflakeConnector, "_get_password", return_value="test_password"
), patch.object(
SnowflakeConnector, "_get_role", return_value="test_role"
), patch.object(
SnowflakeConnector, "_get_warehouse", return_value="test_warehouse"
), patch.object(
SnowflakeConnector, "_get_host", return_value="test_host"
):
yield


@pytest.fixture
def mock_dependencies(mock_snowflake_connection):
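    """Patches the schema/column lookup and table conversion helpers with canned test data."""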
valid_schemas_tables_columns_df_alias = pd.DataFrame(
{
"TABLE_NAME": ["ALIAS"] * 4,
"COLUMN_NAME": ["ZIP_CODE", "AREA_CODE", "BAD_ALIAS", "CBSA"],
"DATA_TYPE": ["VARCHAR", "INTEGER", "DATETIME", "DECIMAL"],
}
)
valid_schemas_tables_columns_df_zip_code = pd.DataFrame(
{
"TABLE_NAME": ["PRODUCTS"],
"COLUMN_NAME": ["SKU"],
"DATA_TYPE": ["NUMBER"],
}
)
valid_schemas_tables_representations = [
valid_schemas_tables_columns_df_alias,
valid_schemas_tables_columns_df_zip_code,
]
table_representations = [
_CONVERTED_TABLE_ALIAS, # Value returned on the first call.
_CONVERTED_TABLE_ZIP_CODE, # Value returned on the second call.
]

with patch(
"semantic_model_generator.main.get_valid_schemas_tables_columns_df",
side_effect=valid_schemas_tables_representations,
), patch(
"semantic_model_generator.main.get_table_representation",
side_effect=table_representations,
):
yield


def test_raw_schema_to_semantic_context(
mock_dependencies, mock_snowflake_connection, mock_snowflake_connection_env
):
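    """raw_schema_to_semantic_context builds a SemanticModel whose YAML matches the expected output."""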
want_yaml = "name: this is the best semantic model ever\ntables:\n - name: ALIAS\n description: ' '\n base_table:\n database: test_db\n schema: schema_test\n table: ALIAS\n filters:\n - name: ' '\n synonyms:\n - ' '\n description: ' '\n expr: ' '\n dimensions:\n - name: ZIP_CODE\n synonyms:\n - ' '\n description: ' '\n expr: ZIP_CODE\n data_type: TEXT\n time_dimensions:\n - name: BAD_ALIAS\n synonyms:\n - ' '\n description: ' '\n expr: BAD_ALIAS\n data_type: TIMESTAMP\n measures:\n - name: AREA_CODE\n synonyms:\n - ' '\n description: ' '\n expr: AREA_CODE\n data_type: NUMBER\n - name: CBSA\n synonyms:\n - ' '\n description: ' '\n expr: CBSA\n data_type: NUMBER\n"

snowflake_account = "test_account"
fqn_tables = ["test_db.schema_test.ALIAS"]
semantic_model_name = "this is the best semantic model ever"

semantic_model = raw_schema_to_semantic_context(
fqn_tables=fqn_tables,
snowflake_account=snowflake_account,
semantic_model_name=semantic_model_name,
)

    # Assert the result is a SemanticModel proto with at least one table.
assert isinstance(semantic_model, semantic_model_pb2.SemanticModel)
assert len(semantic_model.tables) > 0

result_yaml = proto_utils.proto_to_yaml(semantic_model)
assert result_yaml == want_yaml

mock_snowflake_connection.assert_called_once_with(
user="test_user",
password="test_password",
account="test_account",
role="test_role",
warehouse="test_warehouse",
host="test_host",
)


@patch("builtins.open", new_callable=mock_open)
def test_generate_base_context_with_placeholder_comments(
mock_file,
mock_dependencies,
mock_snowflake_connection,
mock_snowflake_connection_env,
):
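    """generate_base_semantic_context_from_snowflake writes YAML with '# <FILL-OUT>' placeholder comments."""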

fqn_tables = ["test_db.schema_test.ALIAS"]
snowflake_account = "test_account"
output_path = "output_model_path.yaml"
semantic_model_name = "my awesome semantic model"

generate_base_semantic_context_from_snowflake(
fqn_tables=fqn_tables,
snowflake_account=snowflake_account,
output_yaml_path=output_path,
semantic_model_name=semantic_model_name,
)

mock_file.assert_called_once_with(output_path, "w")
    # Assert the file was written with placeholder comments added.
mock_file().write.assert_called_once_with(
"name: my awesome semantic model\ntables:\n - name: ALIAS\n description: ' ' # <FILL-OUT>\n base_table:\n database: test_db\n schema: schema_test\n table: ALIAS\n filters:\n - name: ' ' # <FILL-OUT>\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: ' ' # <FILL-OUT>\n dimensions:\n - name: ZIP_CODE\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: ZIP_CODE\n data_type: TEXT\n time_dimensions:\n - name: BAD_ALIAS\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: BAD_ALIAS\n data_type: TIMESTAMP\n measures:\n - name: AREA_CODE\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: AREA_CODE\n data_type: NUMBER\n - name: CBSA\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: CBSA\n data_type: NUMBER\n"
)


@patch("builtins.open", new_callable=mock_open)
def test_generate_base_context_with_placeholder_comments_cross_database_cross_schema(
mock_file,
mock_dependencies,
mock_snowflake_connection,
mock_snowflake_connection_env,
):
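    """YAML generation handles tables in different databases and schemas, and includes sample values."""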

fqn_tables = [
"test_db.schema_test.ALIAS",
"a_different_database.a_different_schema.PRODUCTS",
]
snowflake_account = "test_account"
output_path = "output_model_path.yaml"
semantic_model_name = "Another Incredible Semantic Model"

generate_base_semantic_context_from_snowflake(
fqn_tables=fqn_tables,
snowflake_account=snowflake_account,
output_yaml_path=output_path,
semantic_model_name=semantic_model_name,
)

mock_file.assert_called_once_with(output_path, "w")
    # Assert the file was written with placeholder comments, sample values, and tables from different databases and schemas.
mock_file().write.assert_called_once_with(
"name: Another Incredible Semantic Model\ntables:\n - name: ALIAS\n description: ' ' # <FILL-OUT>\n base_table:\n database: test_db\n schema: schema_test\n table: ALIAS\n filters:\n - name: ' ' # <FILL-OUT>\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: ' ' # <FILL-OUT>\n dimensions:\n - name: ZIP_CODE\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: ZIP_CODE\n data_type: TEXT\n time_dimensions:\n - name: BAD_ALIAS\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: BAD_ALIAS\n data_type: TIMESTAMP\n measures:\n - name: AREA_CODE\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: AREA_CODE\n data_type: NUMBER\n - name: CBSA\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: CBSA\n data_type: NUMBER\n - name: PRODUCTS\n description: ' ' # <FILL-OUT>\n base_table:\n database: a_different_database\n schema: a_different_schema\n table: PRODUCTS\n filters:\n - name: ' ' # <FILL-OUT>\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: ' ' # <FILL-OUT>\n measures:\n - name: SKU\n synonyms:\n - ' ' # <FILL-OUT>\n description: ' ' # <FILL-OUT>\n expr: SKU\n data_type: NUMBER\n sample_values:\n - '1'\n - '2'\n - '3'\n"
)


def test_semantic_model_to_yaml() -> None:
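    """A manually constructed SemanticModel proto serializes to the expected YAML."""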
want_yaml = "name: transaction_ctx\ntables:\n - name: transactions\n description: A table containing data about financial transactions. Each row contains\n details of a financial transaction.\n base_table:\n database: my_database\n schema: my_schema\n table: transactions\n dimensions:\n - name: transaction_id\n description: A unique id for this transaction.\n expr: transaction_id\n data_type: BIGINT\n unique: true\n time_dimensions:\n - name: initiation_date\n description: Timestamp when the transaction was initiated. In UTC.\n expr: initiation_date\n data_type: DATETIME\n measures:\n - name: amount\n description: The amount of this transaction.\n expr: amount\n data_type: DECIMAL\n default_aggregation: sum\n"
got = semantic_model_pb2.SemanticModel(
name="transaction_ctx",
tables=[
semantic_model_pb2.Table(
name="transactions",
description="A table containing data about financial transactions. Each row contains details of a financial transaction.",
base_table=semantic_model_pb2.FullyQualifiedTable(
database="my_database",
schema="my_schema",
table="transactions",
),
time_dimensions=[
semantic_model_pb2.TimeDimension(
name="initiation_date",
description="Timestamp when the transaction was initiated. In UTC.",
expr="initiation_date",
data_type="DATETIME",
unique=False,
)
],
measures=[
semantic_model_pb2.Measure(
name="amount",
description="The amount of this transaction.",
expr="amount",
data_type="DECIMAL",
default_aggregation=semantic_model_pb2.AggregationType.sum,
),
],
dimensions=[
semantic_model_pb2.Dimension(
name="transaction_id",
description="A unique id for this transaction.",
expr="transaction_id",
data_type="BIGINT",
unique=True,
)
],
)
],
)
got_yaml = proto_utils.proto_to_yaml(got)
assert got_yaml == want_yaml

# Parse the YAML strings into Python data structures
want_data = yaml.safe_load(want_yaml)
got_data = yaml.safe_load(got_yaml)

# Now compare the data structures
assert (
want_data == got_data
), "The generated YAML does not match the expected structure."
