Skip to content

Commit

Permalink
updates to shuffle metadata and realias
Browse files Browse the repository at this point in the history
  • Loading branch information
wongjingping committed Jun 7, 2024
1 parent 2738243 commit 2ce6130
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 18 deletions.
36 changes: 26 additions & 10 deletions defog_utils/utils_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,12 +605,18 @@ def shuffle_table_metadata(md_str: str, seed: Optional[int] = None) -> str:
f"join_statements does not contain 'join':\n\"{join_statements}\"\ndropping this join statement."
)
join_statements = ""
# remove create schema statements if present
if "CREATE SCHEMA" in md_str:
schema_str_list = []
while "CREATE SCHEMA" in md_str:
logging.debug(
f"md_str contains CREATE SCHEMA statements\nmd_str: {md_str}\nshuffling columns within tables only"
)
md_str = re.sub("CREATE SCHEMA.*?;", "", md_str)
# use regex to find and extract line with CREATE SCHEMA
schema_match = re.search(r"CREATE SCHEMA.*;", md_str)
if schema_match:
schema_line = schema_match.group(0)
schema_str_list.append(schema_line)
md_str = md_str.replace(schema_line, "")

md_table_list = md_str.split(");")
md_table_shuffled_list = []
for md_table in md_table_list:
Expand Down Expand Up @@ -641,6 +647,8 @@ def shuffle_table_metadata(md_str: str, seed: Optional[int] = None) -> str:
md_table_shuffled_str += (
"\nHere is a list of joinable columns:\n" + join_statements
)
if schema_str_list != []:
md_table_shuffled_str = "\n".join(schema_str_list) + "\n" + md_table_shuffled_str
return md_table_shuffled_str


Expand All @@ -651,7 +659,7 @@ def replace_alias(
Replaces the table aliases in the SQL query with the new aliases provided in the new_alias_map.
`new_alias_map` is a dict of table_name -> new_alias.
"""
parsed = parse_one(sql)
parsed = parse_one(sql, dialect=dialect)
existing_alias_map = {}
# save and update the existing table aliases
for node in parsed.walk():
Expand All @@ -664,12 +672,20 @@ def replace_alias(
if table_name in new_alias_map:
node.set("alias", new_alias_map[table_name])
else:
existing_alias_map[table_name] = table_name
node.set("alias", new_alias_map.get(table_name, table_name))
# go through each column, and if it has a table alias, replace it with the new alias
for node in parsed.walk():
if isinstance(node, exp.Column):
if node.table and node.table in existing_alias_map:
original_table_name = existing_alias_map[node.table]
if original_table_name in new_alias_map:
node.set("table", new_alias_map[original_table_name])
return parsed.sql(dialect)
print(f"{node}: {node.table}")
if node.table:
# if in existing alias map, set the table to the new alias
if node.table in existing_alias_map:
original_table_name = existing_alias_map[node.table]
print(f"original_table_name: {original_table_name}")
if original_table_name in new_alias_map:
node.set("table", new_alias_map[original_table_name])
# else if in new alias map, set the table to the new alias
elif node.table in new_alias_map:
print(f"Setting table to {new_alias_map[node.table]}")
node.set("table", new_alias_map[node.table])
return parsed.sql(dialect, normalize_functions="upper", comments=False)
34 changes: 26 additions & 8 deletions tests/test_utils_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,8 @@ def test_shuffle_table_metadata_seed(self):
TEST_DB.PUBLIC.CUSTOMERS.CUSTOMER_ID can be joined with patient.ssn
"""
expected_md_shuffled = """CREATE TABLE physician (
expected_md_shuffled = """CREATE SCHEMA IF NOT EXISTS TEST_DB;
CREATE TABLE physician (
position character varying,
ssn integer, --social security number of the physician
name character varying, --name of the physician
Expand Down Expand Up @@ -894,20 +895,29 @@ def test_add_space_padding(self):


class TestReplaceAlias(unittest.TestCase):
def test_replace_alias(self):
sql = "SELECT a.name, b.age FROM users a, info b WHERE a.id = b.id"
def test_replace_alias_change_existing_alias(self):
# should replace users a with users u and info b with info i
sql = "SELECT a.name, b.age FROM users a JOIN info b ON a.id = b.id"
new_alias_map = {"users": "u", "info": "i"}
expected = "SELECT u.name, i.age FROM users AS u, info AS i WHERE u.id = i.id"
expected = "SELECT u.name, i.age FROM users AS u JOIN info AS i ON u.id = i.id"
self.assertEqual(replace_alias(sql, new_alias_map), expected)

def test_replace_alias_no_alias(self):
sql = "SELECT name, age FROM users JOIN info WHERE users.id = info.id"
# should add alias to tables if not present
sql = "SELECT name, age FROM users JOIN info ON users.id = info.id"
new_alias_map = {"users": "u", "info": "i"}
expected = "SELECT name, age FROM users AS u JOIN info AS i ON u.id = i.id"
self.assertEqual(replace_alias(sql, new_alias_map), expected)

def test_replace_alias_no_table_alias_have_column_table(self):
# should replace alias in columns using new_alias_map and add alias to tables
sql = "SELECT users.name, info.age FROM users JOIN info ON users.id = info.id"
new_alias_map = {"users": "u", "info": "i"}
expected = "SELECT name, age FROM users, info WHERE u.id = i.id"
expected = "SELECT u.name, i.age FROM users AS u JOIN info AS i ON u.id = i.id"
self.assertEqual(replace_alias(sql, new_alias_map), expected)

def test_replace_alias_no_change(self):
sql = "SELECT a.name, b.age FROM users AS a, info AS b WHERE a.id = b.id"
sql = "SELECT a.name, b.age FROM users AS a JOIN info AS b ON a.id = b.id"
new_alias_map = {"users": "a", "logs": "l"}
expected = sql
self.assertEqual(replace_alias(sql, new_alias_map), expected)
Expand All @@ -931,7 +941,15 @@ def test_sql_1(self):
"game_events": "ge",
"player_stats": "ps",
}
expected = """WITH player_games AS (SELECT g.gm_game_id AS game_id, p.pl_player_id AS player_id, p.pl_team_id AS team_id FROM games AS g JOIN players AS p ON g.gm_home_team_id = p.pl_team_id OR g.gm_away_team_id = p.pl_team_id), unique_games AS (SELECT DISTINCT gm_game_id FROM games) SELECT CAST(COUNT(DISTINCT pg.game_id) AS REAL) / NULLIF(COUNT(DISTINCT ug.gm_game_id), 0) AS fraction FROM player_games AS pg RIGHT JOIN unique_games AS ug ON pg.game_id = ug.gm_game_id"""
expected = """WITH player_games AS (SELECT g.gm_game_id AS game_id, p.pl_player_id AS player_id, p.pl_team_id AS team_id FROM games AS g JOIN players AS p ON g.gm_home_team_id = p.pl_team_id OR g.gm_away_team_id = p.pl_team_id), unique_games AS (SELECT DISTINCT gm_game_id FROM games AS g) SELECT CAST(COUNT(DISTINCT pg.game_id) AS DOUBLE PRECISION) / NULLIF(COUNT(DISTINCT ug.gm_game_id), 0) AS fraction FROM player_games AS pg RIGHT JOIN unique_games AS ug ON pg.game_id = ug.gm_game_id"""
result = replace_alias(sql, new_alias_map)
print(result)
self.assertEqual(result, expected)

def test_sql_2(self):
sql = "SELECT train.year, manufacturer, AVG(train.capacity) AS average_capacity FROM train WHERE train.manufacturer ILIKE '%Mfr 1%' GROUP BY train.year, train.manufacturer ORDER BY train.year;"
new_alias_map = {"train": "t"}
expected = "SELECT t.year, manufacturer, AVG(t.capacity) AS average_capacity FROM train AS t WHERE t.manufacturer ILIKE '%Mfr 1%' GROUP BY t.year, t.manufacturer ORDER BY t.year"
result = replace_alias(sql, new_alias_map)
print(result)
self.assertEqual(result, expected)
Expand Down

0 comments on commit 2ce6130

Please sign in to comment.