Skip to content

Commit

Permalink
Refactor schema generation logic in generate_sqlserver_schema
Browse files Browse the repository at this point in the history
  • Loading branch information
rishsriv committed Sep 16, 2024
1 parent f733d26 commit 7702c71
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions defog/generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,6 @@ def generate_sqlserver_schema(
"SELECT table_catalog, table_schema, table_name FROM information_schema.tables;"
)

# remove all system tables
tables = []
for row in cur.fetchall():
if (
Expand Down Expand Up @@ -602,16 +601,17 @@ def generate_sqlserver_schema(

print("Getting schema for each table in your database...")
# get the schema for each table
for table_name in tables:
for orig_table_name in tables:
# if there are two dots, we have the database name and the schema name
if table_name.count(".") == 2:
db_name, schema, table_name = table_name.split(".", 2)
elif table_name.count(".") == 1:
schema, table_name = table_name.split(".", 1)
if orig_table_name.count(".") == 2:
db_name, schema, table_name = orig_table_name.split(".", 2)
elif orig_table_name.count(".") == 1:
schema, table_name = orig_table_name.split(".", 1)
db_name = self.db_creds["database"]
else:
schema = "dbo"
db_name = self.db_creds["database"]
table_name = orig_table_name
cur.execute(
"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ? AND table_schema = ? AND table_catalog = ?;",
table_name,
Expand All @@ -622,7 +622,7 @@ def generate_sqlserver_schema(
rows = [row for row in rows]
rows = [{"column_name": i[0], "data_type": i[1]} for i in rows]
if len(rows) > 0:
schemas[table_name] = rows
schemas[orig_table_name] = rows

conn.close()
if upload:
Expand Down

0 comments on commit 7702c71

Please sign in to comment.