From 273e5e80bb589150a0151b13ce86fdf65df8d526 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 15 Apr 2024 14:36:52 -0700 Subject: [PATCH] tests --- .../tests/generate_model_test.py | 285 ++++++++++++++++++ 1 file changed, 285 insertions(+) create mode 100644 semantic_model_generator/tests/generate_model_test.py diff --git a/semantic_model_generator/tests/generate_model_test.py b/semantic_model_generator/tests/generate_model_test.py new file mode 100644 index 00000000..7a4143f4 --- /dev/null +++ b/semantic_model_generator/tests/generate_model_test.py @@ -0,0 +1,285 @@ +from unittest.mock import MagicMock, mock_open, patch + +import pandas as pd +import pytest +import yaml + +from semantic_model_generator.data_processing import proto_utils +from semantic_model_generator.data_processing.data_types import Column, Table +from semantic_model_generator.generate_model import ( + _to_snake_case, + generate_base_semantic_context_from_snowflake, + raw_schema_to_semantic_context, +) +from semantic_model_generator.protos import semantic_model_pb2 +from semantic_model_generator.snowflake_utils.snowflake_connector import ( + SnowflakeConnector, +) + + +def test_to_snake_case(): + text = "Hello World-How are_you" + + assert "hello_world_how_are_you" == _to_snake_case(text) + + +@pytest.fixture +def mock_snowflake_connection(): + """Fixture to mock the snowflake_connection function.""" + with patch( + "semantic_model_generator.snowflake_utils.snowflake_connector.snowflake_connection" + ) as mock: + mock.return_value = MagicMock() + yield mock + + +_CONVERTED_TABLE_ALIAS = Table( + id_=0, + name="ALIAS", + columns=[ + Column( + id_=0, + column_name="ZIP_CODE", + column_type="TEXT", + values=None, + comment=None, + ), + Column( + id_=1, + column_name="AREA_CODE", + column_type="NUMBER", + values=None, + comment=None, + ), + Column( + id_=2, + column_name="BAD_ALIAS", + column_type="TIMESTAMP", + values=None, + comment=None, + ), + Column( + id_=3, + column_name="CBSA", + column_type="NUMBER", + values=None, + comment=None, + ), + ], + comment=None, +) + +_CONVERTED_TABLE_ZIP_CODE = Table( + id_=0, + name="PRODUCTS", + columns=[ + Column( + id_=0, + column_name="SKU", + column_type="NUMBER", + values=["1", "2", "3"], + comment=None, + ), + ], + comment=None, +) + + +@pytest.fixture +def mock_snowflake_connection_env(monkeypatch): + # Mock environment variable + monkeypatch.setenv("SNOWFLAKE_HOST", "test_host") + + # Use this fixture to also patch instance methods if needed + with patch.object( + SnowflakeConnector, "_get_user", return_value="test_user" + ), patch.object( + SnowflakeConnector, "_get_password", return_value="test_password" + ), patch.object( + SnowflakeConnector, "_get_role", return_value="test_role" + ), patch.object( + SnowflakeConnector, "_get_warehouse", return_value="test_warehouse" + ), patch.object( + SnowflakeConnector, "_get_host", return_value="test_host" + ): + yield + + +@pytest.fixture +def mock_dependencies(mock_snowflake_connection): + valid_schemas_tables_columns_df_alias = pd.DataFrame( + { + "TABLE_NAME": ["ALIAS"] * 4, + "COLUMN_NAME": ["ZIP_CODE", "AREA_CODE", "BAD_ALIAS", "CBSA"], + "DATA_TYPE": ["VARCHAR", "INTEGER", "DATETIME", "DECIMAL"], + } + ) + valid_schemas_tables_columns_df_zip_code = pd.DataFrame( + { + "TABLE_NAME": ["PRODUCTS"], + "COLUMN_NAME": ["SKU"], + "DATA_TYPE": ["NUMBER"], + } + ) + valid_schemas_tables_representations = [ + valid_schemas_tables_columns_df_alias, + valid_schemas_tables_columns_df_zip_code, + ] + table_representations = [ + _CONVERTED_TABLE_ALIAS, # Value returned on the first call. + _CONVERTED_TABLE_ZIP_CODE, # Value returned on the second call. + ] + + with patch( + "semantic_model_generator.main.get_valid_schemas_tables_columns_df", + side_effect=valid_schemas_tables_representations, + ), patch( + "semantic_model_generator.main.get_table_representation", + side_effect=table_representations, + ): + yield + + +def test_raw_schema_to_semantic_context( + mock_dependencies, mock_snowflake_connection, mock_snowflake_connection_env +): + want_yaml = "name: this is the best semantic model ever\ntables:\n - name: ALIAS\n description: ' '\n base_table:\n database: test_db\n schema: schema_test\n table: ALIAS\n filters:\n - name: ' '\n synonyms:\n - ' '\n description: ' '\n expr: ' '\n dimensions:\n - name: ZIP_CODE\n synonyms:\n - ' '\n description: ' '\n expr: ZIP_CODE\n data_type: TEXT\n time_dimensions:\n - name: BAD_ALIAS\n synonyms:\n - ' '\n description: ' '\n expr: BAD_ALIAS\n data_type: TIMESTAMP\n measures:\n - name: AREA_CODE\n synonyms:\n - ' '\n description: ' '\n expr: AREA_CODE\n data_type: NUMBER\n - name: CBSA\n synonyms:\n - ' '\n description: ' '\n expr: CBSA\n data_type: NUMBER\n" + + snowflake_account = "test_account" + fqn_tables = ["test_db.schema_test.ALIAS"] + semantic_model_name = "this is the best semantic model ever" + + semantic_model = raw_schema_to_semantic_context( + fqn_tables=fqn_tables, + snowflake_account=snowflake_account, + semantic_model_name=semantic_model_name, + ) + + # Assert the result as expected + assert isinstance(semantic_model, semantic_model_pb2.SemanticModel) + assert len(semantic_model.tables) > 0 + + result_yaml = proto_utils.proto_to_yaml(semantic_model) + assert result_yaml == want_yaml + + mock_snowflake_connection.assert_called_once_with( + user="test_user", + password="test_password", + account="test_account", + role="test_role", + warehouse="test_warehouse", + host="test_host", + ) + + +@patch("builtins.open", new_callable=mock_open) +def test_generate_base_context_with_placeholder_comments( + mock_file, + mock_dependencies, + mock_snowflake_connection, + mock_snowflake_connection_env, +): + + fqn_tables = ["test_db.schema_test.ALIAS"] + snowflake_account = "test_account" + output_path = "output_model_path.yaml" + semantic_model_name = "my awesome semantic model" + + generate_base_semantic_context_from_snowflake( + fqn_tables=fqn_tables, + snowflake_account=snowflake_account, + output_yaml_path=output_path, + semantic_model_name=semantic_model_name, + ) + + mock_file.assert_called_once_with(output_path, "w") + # Assert file save called with placeholder comments added. + mock_file().write.assert_called_once_with( + "name: my awesome semantic model\ntables:\n - name: ALIAS\n description: ' ' # \n base_table:\n database: test_db\n schema: schema_test\n table: ALIAS\n filters:\n - name: ' ' # \n synonyms:\n - ' ' # \n description: ' ' # \n expr: ' ' # \n dimensions:\n - name: ZIP_CODE\n synonyms:\n - ' ' # \n description: ' ' # \n expr: ZIP_CODE\n data_type: TEXT\n time_dimensions:\n - name: BAD_ALIAS\n synonyms:\n - ' ' # \n description: ' ' # \n expr: BAD_ALIAS\n data_type: TIMESTAMP\n measures:\n - name: AREA_CODE\n synonyms:\n - ' ' # \n description: ' ' # \n expr: AREA_CODE\n data_type: NUMBER\n - name: CBSA\n synonyms:\n - ' ' # \n description: ' ' # \n expr: CBSA\n data_type: NUMBER\n" + ) + + +@patch("builtins.open", new_callable=mock_open) +def test_generate_base_context_with_placeholder_comments_cross_database_cross_schema( + mock_file, + mock_dependencies, + mock_snowflake_connection, + mock_snowflake_connection_env, +): + + fqn_tables = [ + "test_db.schema_test.ALIAS", + "a_different_database.a_different_schema.PRODUCTS", + ] + snowflake_account = "test_account" + output_path = "output_model_path.yaml" + semantic_model_name = "Another Incredible Semantic Model" + + generate_base_semantic_context_from_snowflake( + fqn_tables=fqn_tables, + snowflake_account=snowflake_account, + output_yaml_path=output_path, + semantic_model_name=semantic_model_name, + ) + + mock_file.assert_called_once_with(output_path, "w") + # Assert file save called with placeholder comments added along with sample values and cross-database + mock_file().write.assert_called_once_with( + "name: Another Incredible Semantic Model\ntables:\n - name: ALIAS\n description: ' ' # \n base_table:\n database: test_db\n schema: schema_test\n table: ALIAS\n filters:\n - name: ' ' # \n synonyms:\n - ' ' # \n description: ' ' # \n expr: ' ' # \n dimensions:\n - name: ZIP_CODE\n synonyms:\n - ' ' # \n description: ' ' # \n expr: ZIP_CODE\n data_type: TEXT\n time_dimensions:\n - name: BAD_ALIAS\n synonyms:\n - ' ' # \n description: ' ' # \n expr: BAD_ALIAS\n data_type: TIMESTAMP\n measures:\n - name: AREA_CODE\n synonyms:\n - ' ' # \n description: ' ' # \n expr: AREA_CODE\n data_type: NUMBER\n - name: CBSA\n synonyms:\n - ' ' # \n description: ' ' # \n expr: CBSA\n data_type: NUMBER\n - name: PRODUCTS\n description: ' ' # \n base_table:\n database: a_different_database\n schema: a_different_schema\n table: PRODUCTS\n filters:\n - name: ' ' # \n synonyms:\n - ' ' # \n description: ' ' # \n expr: ' ' # \n measures:\n - name: SKU\n synonyms:\n - ' ' # \n description: ' ' # \n expr: SKU\n data_type: NUMBER\n sample_values:\n - '1'\n - '2'\n - '3'\n" + ) + + +def test_semantic_model_to_yaml() -> None: + want_yaml = "name: transaction_ctx\ntables:\n - name: transactions\n description: A table containing data about financial transactions. Each row contains\n details of a financial transaction.\n base_table:\n database: my_database\n schema: my_schema\n table: transactions\n dimensions:\n - name: transaction_id\n description: A unique id for this transaction.\n expr: transaction_id\n data_type: BIGINT\n unique: true\n time_dimensions:\n - name: initiation_date\n description: Timestamp when the transaction was initiated. In UTC.\n expr: initiation_date\n data_type: DATETIME\n measures:\n - name: amount\n description: The amount of this transaction.\n expr: amount\n data_type: DECIMAL\n default_aggregation: sum\n" + got = semantic_model_pb2.SemanticModel( + name="transaction_ctx", + tables=[ + semantic_model_pb2.Table( + name="transactions", + description="A table containing data about financial transactions. Each row contains details of a financial transaction.", + base_table=semantic_model_pb2.FullyQualifiedTable( + database="my_database", + schema="my_schema", + table="transactions", + ), + time_dimensions=[ + semantic_model_pb2.TimeDimension( + name="initiation_date", + description="Timestamp when the transaction was initiated. In UTC.", + expr="initiation_date", + data_type="DATETIME", + unique=False, + ) + ], + measures=[ + semantic_model_pb2.Measure( + name="amount", + description="The amount of this transaction.", + expr="amount", + data_type="DECIMAL", + default_aggregation=semantic_model_pb2.AggregationType.sum, + ), + ], + dimensions=[ + semantic_model_pb2.Dimension( + name="transaction_id", + description="A unique id for this transaction.", + expr="transaction_id", + data_type="BIGINT", + unique=True, + ) + ], + ) + ], + ) + got_yaml = proto_utils.proto_to_yaml(got) + assert got_yaml == want_yaml + + # Parse the YAML strings into Python data structures + want_data = yaml.safe_load(want_yaml) + got_data = yaml.safe_load(got_yaml) + + # Now compare the data structures + assert ( + want_data == got_data + ), "The generated YAML does not match the expected structure."