Skip to content

Commit

Permalink
feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jhilgart committed Apr 18, 2024
1 parent f348886 commit 9cdb945
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 26 deletions.
30 changes: 12 additions & 18 deletions semantic_model_generator/sqlgen/generate_sql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Set, Union
from typing import Optional, Union

import sqlglot
from sqlglot.dialects.snowflake import Snowflake
Expand Down Expand Up @@ -42,15 +42,17 @@ def _convert_to_snowflake_sql(sql: str) -> str:
return expression.sql()


def _create_select_statement(table: Table, cols: Set[str], limit: int) -> str:
def _create_select_statement(table: Table, limit: int) -> str:
def _return_col_or_expr(
col: Union[TimeDimension, Dimension, Measure], cols: Set[str]
col: Union[TimeDimension, Dimension, Measure]
) -> Optional[str]:
# TODO(jhilgart): Handle quoted names properly.
if col.name.lower() not in cols:
return None
if " " in col.name:
raise ValueError(
f"Column names should not have spaces in them. Passed = {col.name}"
)
expr = (
f'{col.expr} as "{col.name}"'
f"{col.expr} as {col.name}"
if col.expr.lower() != col.name.lower()
else f"{col.expr}"
)
Expand All @@ -60,11 +62,11 @@ def _return_col_or_expr(

columns = []
for dim_col in table.dimensions:
columns.append(_return_col_or_expr(dim_col, cols))
columns.append(_return_col_or_expr(dim_col))
for time_col in table.measures:
columns.append(_return_col_or_expr(time_col, cols))
columns.append(_return_col_or_expr(time_col))
for time_dim_col in table.time_dimensions:
columns.append(_return_col_or_expr(time_dim_col, cols))
columns.append(_return_col_or_expr(time_dim_col))

filtered_columns = [item for item in columns if item is not None]
if len(filtered_columns) == 0:
Expand All @@ -89,15 +91,7 @@ def generate_select_with_all_cols(table: Table, limit: int) -> str:
Returns:
str: A SQL statement formatted for Snowflake.
"""
cols = []
for time_dim_col in table.time_dimensions:
cols.append(time_dim_col.name.lower())
for dim_col in table.dimensions:
cols.append(dim_col.name.lower())
for meausre_col in table.measures:
cols.append(meausre_col.name.lower())

col_set = set(cols)
select = _create_select_statement(table, col_set, limit)
select = _create_select_statement(table, limit)

return _convert_to_snowflake_sql(select)
34 changes: 26 additions & 8 deletions semantic_model_generator/tests/generate_sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@
description="Geographical region",
)

# Dimension fixture whose name contains spaces ("Regions In The World").
# It is the only dimension of _TEST_TABLE_INVALID_NAME, which
# test_table_invalid_col_name passes to generate_select_with_all_cols to
# check that a ValueError is raised for column names containing spaces.
dimension_example_invalid_name = Dimension(
name="Regions In The World",
synonyms=["Area", "Locale"],
description="Geographical region",
expr="region_code",
data_type="string",
unique=False,
sample_values=["North", "South", "East", "West"],
)

time_dimension_example = TimeDimension(
name="Date",
synonyms=["Time"],
Expand All @@ -36,7 +46,7 @@
)

measure_example = Measure(
name="Total Sales",
name="Total_Sales",
synonyms=["Sales", "Revenue"],
description="Total sales amount",
expr="sales_amount - sales_total",
Expand All @@ -60,27 +70,26 @@
measures=[measure_example],
)

_TEST_VALID_TABLE = Table(

_TEST_TABLE_NO_COLS = Table(
name="Transactions",
synonyms=["Transaction Records"],
description="Table containing transaction records",
base_table=fully_qualified_table_example,
dimensions=[dimension_example],
time_dimensions=[time_dimension_example],
measures=[measure_example],
dimensions=[dimension_example_no_cols],
)

_TEST_TABLE_NO_COLS = Table(
_TEST_TABLE_INVALID_NAME = Table(
name="Transactions",
synonyms=["Transaction Records"],
description="Table containing transaction records",
base_table=fully_qualified_table_example,
dimensions=[dimension_example_no_cols],
dimensions=[dimension_example_invalid_name],
)


def test_valid_table_sql_with_expr():
want = 'SELECT region_code AS "Region", sales_amount - sales_total AS "Total Sales", transaction_date AS "Date" FROM SalesDB.public.transactions LIMIT 100'
want = "SELECT region_code AS Region, sales_amount - sales_total AS Total_Sales, transaction_date AS Date FROM SalesDB.public.transactions LIMIT 100"
generated_sql = generate_select_with_all_cols(_TEST_VALID_TABLE, 100)
assert generated_sql == want

Expand All @@ -92,3 +101,12 @@ def test_table_no_cols():
str(excinfo.value)
== "No columns found for table Transactions. Please remove this"
)


def test_table_invalid_col_name():
    """Check that a column name containing spaces is rejected.

    Generating a SELECT for a table whose only dimension is named
    "Regions In The World" must raise a ValueError whose message echoes
    the offending name.
    """
    expected_message = (
        "Column names should not have spaces in them. "
        "Passed = Regions In The World"
    )

    with pytest.raises(ValueError) as raised:
        generate_select_with_all_cols(_TEST_TABLE_INVALID_NAME, 100)

    # Checked after the `with` block: statements following the raising call
    # inside the block would never execute.
    actual_message = str(raised.value)
    assert actual_message == expected_message

0 comments on commit 9cdb945

Please sign in to comment.