Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add strftime to SQL features #4

Merged
merged 4 commits into from
Jun 7, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion defog_utils/utils_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class SqlFeatures(Features):
has_date_int: bool = False
date_trunc: bool = False
date_part: bool = False
strftime: bool = False
current_date_time: bool = False
interval: bool = False
date_time_type_conversion: bool = False
Expand Down Expand Up @@ -107,6 +108,7 @@ class SqlFeatures(Features):
]
current_date_time_expressions = [
exp.CurrentDate,
exp.CurrentDatetime,
exp.CurrentTime,
exp.CurrentTimestamp,
]
Expand Down Expand Up @@ -262,6 +264,7 @@ def get_sql_features(
md_cols: Optional[Set[str]] = None,
md_tables: Optional[Set[str]] = None,
extra_column_info: Optional[Dict[str, str]] = None,
dialect: str = "postgres",
) -> SqlFeatures:
"""
Extracts features from a SQL query string by making a single pass through the parsed SQL abstract syntax tree (AST).
Expand Down Expand Up @@ -293,7 +296,7 @@ def get_sql_features(
if " ~ '" in sql:
sql = sql.replace(" ~ '", " LIKE '")
sql = re.sub("character varying", "varchar", sql, flags=re.IGNORECASE)
parsed = parse_one(sql)
parsed = parse_one(sql, dialect)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for this - thanks!

# internal state for computing various summarized/derived quantities
columns_in_sql = set()
tables_in_sql = set()
Expand Down Expand Up @@ -385,6 +388,8 @@ def get_sql_features(
features.rank = True
elif isinstance(node, exp.DateTrunc):
features.date_trunc = True
elif isinstance(node, exp.TimeToStr):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL!

features.strftime = True
elif isinstance(node, exp.Extract):
features.date_part = True
elif type(node) in current_date_time_expressions:
Expand Down Expand Up @@ -444,10 +449,16 @@ def get_sql_features(
features.rank = True
elif node_name == "date_part":
features.date_part = True
elif node_name == "strftime":
features.strftime = True
elif node_name == "now":
features.current_date_time = True
elif node_name == "to_date" or node_name == "to_timestamp":
features.date_time_type_conversion = True
# other non-defined non-Anonymous expressions in sqlglot
else:
if "now" in str(node).lower():
features.current_date_time = True

if md_cols and md_tables:
md_cols = set([c.lower() for c in md_cols])
Expand Down
Loading