Skip to content

Commit

Permalink
refactor: remove more sqlparse (apache#31032)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Nov 26, 2024
1 parent 9224051 commit 09802ac
Show file tree
Hide file tree
Showing 15 changed files with 95 additions and 172 deletions.
7 changes: 2 additions & 5 deletions superset/commands/dataset/duplicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from superset.exceptions import SupersetErrorException
from superset.extensions import db
from superset.models.core import Database
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.utils.decorators import on_error, transaction

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,10 +70,7 @@ def run(self) -> Model:
table.normalize_columns = self._base_model.normalize_columns
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.is_sqllab_view = True
table.sql = ParsedQuery(
self._base_model.sql,
engine=database.db_engine_spec.engine,
).stripped()
table.sql = self._base_model.sql.strip().strip(";")
db.session.add(table)
cols = []
for config_ in self._base_model.columns:
Expand Down
2 changes: 1 addition & 1 deletion superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,7 +1784,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# def DATASET_HEALTH_CHECK(datasource: SqlaTable) -> Optional[str]:
# if (
# datasource.sql and
# len(sql_parse.ParsedQuery(datasource.sql, strip_comments=True).tables) == 1
# len(SQLScript(datasource.sql).tables) == 1
# ):
# return (
# "This virtual dataset queries only one table and therefore could be "
Expand Down
31 changes: 5 additions & 26 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table
from sqlalchemy.sql.elements import ColumnClause, TextClause
from sqlalchemy.sql.expression import Label, TextAsFrom
from sqlalchemy.sql.expression import Label
from sqlalchemy.sql.selectable import Alias, TableClause

from superset import app, db, is_feature_enabled, security_manager
Expand Down Expand Up @@ -104,7 +104,7 @@
QueryResult,
)
from superset.models.slice import Slice
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
Expand Down Expand Up @@ -1469,34 +1469,13 @@ def get_sqla_table(self) -> TableClause:
return tbl

def get_from_clause(
self, template_processor: BaseTemplateProcessor | None = None
self,
template_processor: BaseTemplateProcessor | None = None,
) -> tuple[TableClause | Alias, str | None]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery. If the FROM is referencing a
CTE, the CTE is returned as the second value in the return tuple.
"""
if not self.is_virtual:
return self.get_sqla_table(), None

from_sql = self.get_rendered_sql(template_processor) + "\n"
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
):
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)

cte = self.db_engine_spec.get_cte_query(from_sql)
from_clause = (
table(self.db_engine_spec.cte_alias)
if cte
else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
)

return from_clause, cte
return super().get_from_clause(template_processor)

def adhoc_metric_to_sqla(
self,
Expand Down
29 changes: 10 additions & 19 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
from superset.sql.parse import SQLScript, Table
from superset.sql.parse import BaseSQLStatement, SQLScript, Table
from superset.sql_parse import ParsedQuery
from superset.superset_typing import (
OAuth2ClientConfig,
Expand Down Expand Up @@ -1737,18 +1737,19 @@ def query_cost_formatter(
)

@classmethod
def process_statement(cls, statement: str, database: Database) -> str:
def process_statement(
cls,
statement: BaseSQLStatement[Any],
database: Database,
) -> str:
"""
Process a SQL statement by stripping and mutating it.
Process a SQL statement by mutating it.
:param statement: A single SQL statement
:param database: Database instance
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement, engine=cls.engine)
sql = parsed_query.stripped()

return database.mutate_sql_based_on_config(sql, is_split=True)
return database.mutate_sql_based_on_config(str(statement), is_split=True)

@classmethod
def estimate_query_cost( # pylint: disable=too-many-arguments
Expand All @@ -1773,8 +1774,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments
"Database does not support cost estimation"
)

parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
parsed_script = SQLScript(sql, engine=cls.engine)

with database.get_raw_connection(
catalog=catalog,
Expand All @@ -1788,7 +1788,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments
cls.process_statement(statement, database),
cursor,
)
for statement in statements
for statement in parsed_script.statements
]

@classmethod
Expand Down Expand Up @@ -2056,15 +2056,6 @@ def update_params_from_encrypted_extra( # pylint: disable=invalid-name
logger.error(ex, exc_info=True)
raise

@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return (
parsed_query.is_select()
or parsed_query.is_explain()
or parsed_query.is_show()
)

@classmethod
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
"""
Expand Down
7 changes: 3 additions & 4 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes

from superset import sql_parse
from superset.constants import TimeGrain
from superset.databases.schemas import encrypted_field_properties, EncryptedString
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.errors import SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.sql.parse import SQLScript
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils, json
Expand Down Expand Up @@ -449,8 +449,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")

parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
parsed_script = SQLScript(sql, engine=cls.engine)

with cls.get_engine(
database,
Expand All @@ -463,7 +462,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments
cls.process_statement(statement, database),
client,
)
for statement in statements
for statement in parsed_script.statements
]

@classmethod
Expand Down
11 changes: 1 addition & 10 deletions superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from superset.exceptions import SupersetException
from superset.extensions import cache_manager
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType

if TYPE_CHECKING:
Expand Down Expand Up @@ -598,15 +598,6 @@ def get_function_names(cls, database: Database) -> list[str]:
# otherwise, return no function names to prevent errors
return []

@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return (
super().is_readonly_query(parsed_query)
or parsed_query.is_set()
or parsed_query.is_show()
)

@classmethod
def has_implicit_cancel(cls) -> bool:
"""
Expand Down
14 changes: 0 additions & 14 deletions superset/db_engine_specs/kusto.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,6 @@ def convert_dttm(
return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)"""
return None

@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return parsed_query.sql.lower().startswith("select")


class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
limit_method = LimitMethod.WRAP_SQL
Expand Down Expand Up @@ -158,15 +153,6 @@ def convert_dttm(

return None

@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""
Pessimistic readonly, 100% sure statement won't mutate anything.
"""
return KustoKqlEngineSpec.is_select_query(
parsed_query
) or parsed_query.sql.startswith(".show")

@classmethod
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
return not parsed_query.sql.startswith(".")
Expand Down
12 changes: 5 additions & 7 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
from superset.sql_parse import (
has_table_query,
insert_rls_in_predicate,
ParsedQuery,
sanitize_clause,
)
from superset.superset_typing import (
Expand Down Expand Up @@ -1039,6 +1038,9 @@ def get_rendered_sql(
"""
Render sql with template engine (Jinja).
"""
if not self.sql:
return ""

sql = self.sql.strip("\t\r\n; ")
if template_processor:
try:
Expand Down Expand Up @@ -1072,13 +1074,9 @@ def get_from_clause(
or a virtual table with it's own subquery. If the FROM is referencing a
CTE, the CTE is returned as the second value in the return tuple.
"""

from_sql = self.get_rendered_sql(template_processor) + "\n"
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
):
parsed_script = SQLScript(from_sql, engine=self.db_engine_spec.engine)
if parsed_script.has_mutation():
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
Expand Down
23 changes: 12 additions & 11 deletions superset/sql_validators/presto_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import logging
import time
from contextlib import closing
from typing import Any
from typing import Any, cast

from superset import app
from superset.models.core import Database
from superset.sql_parse import ParsedQuery
from superset.sql.parse import SQLScript, SQLStatement
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
from superset.utils.core import QuerySource

Expand All @@ -46,17 +46,15 @@ class PrestoDBSQLValidator(BaseSQLValidator):
@classmethod
def validate_statement(
cls,
statement: str,
statement: SQLStatement,
database: Database,
cursor: Any,
) -> SQLValidationAnnotation | None:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
sql = parsed_query.stripped()

# Hook to allow environment-specific mutation (usually comments) to the SQL
sql = database.mutate_sql_based_on_config(sql)
sql = database.mutate_sql_based_on_config(str(statement))

# Transform the final statement to an explain call before sending it on
# to presto to validate
Expand Down Expand Up @@ -155,10 +153,9 @@ def validate(
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable.
"""
parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
statements = parsed_query.get_statements()
parsed_script = SQLScript(sql, engine=database.db_engine_spec.engine)

logger.info("Validating %i statement(s)", len(statements))
logger.info("Validating %i statement(s)", len(parsed_script.statements))
# todo(hughhh): update this to use new database.get_raw_connection()
# this function keeps stalling CI
with database.get_sqla_engine(
Expand All @@ -171,8 +168,12 @@ def validate(
annotations: list[SQLValidationAnnotation] = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in parsed_query.get_statements():
annotation = cls.validate_statement(statement, database, cursor)
for statement in parsed_script.statements:
annotation = cls.validate_statement(
cast(SQLStatement, statement),
database,
cursor,
)
if annotation:
annotations.append(annotation)
logger.debug("Validation found %i error(s)", len(annotations))
Expand Down
8 changes: 2 additions & 6 deletions superset/sqllab/query_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from superset import is_feature_enabled
from superset.commands.sql_lab.execute import SqlQueryRender
from superset.errors import SupersetErrorType
from superset.sql_parse import ParsedQuery
from superset.sqllab.exceptions import SqlLabException
from superset.utils import core as utils

Expand Down Expand Up @@ -58,12 +57,9 @@ def render(self, execution_context: SqlJsonExecutionContext) -> str:
database=query_model.database, query=query_model
)

parsed_query = ParsedQuery(
query_model.sql,
engine=query_model.database.db_engine_spec.engine,
)
rendered_query = sql_template_processor.process_template(
parsed_query.stripped(), **execution_context.template_params
query_model.sql.strip().strip(";"),
**execution_context.template_params,
)
self._validate(execution_context, rendered_query, sql_template_processor)
return rendered_query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.test_app import app
Expand Down Expand Up @@ -310,20 +310,6 @@ def test_calculated_column_in_order_by_base_engine_spec(self):
)


def test_is_readonly():
def is_readonly(sql: str) -> bool:
return BaseEngineSpec.is_readonly_query(ParsedQuery(sql))

assert is_readonly("SHOW LOCKS test EXTENDED")
assert not is_readonly("SET hivevar:desc='Legislators'")
assert not is_readonly("UPDATE t1 SET col1 = NULL")
assert is_readonly("EXPLAIN SELECT 1")
assert is_readonly("SELECT 1")
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
assert is_readonly("SHOW CATALOGS")
assert is_readonly("SHOW TABLES")


def test_time_grain_denylist():
config = app.config.copy()
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M", "SQLITE_NONEXISTENT_GRAIN"]
Expand Down
Loading

0 comments on commit 09802ac

Please sign in to comment.