From e4b3ecd3723e95c8b20372f65d837ee546caaa27 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 8 Jan 2025 22:11:28 -0500 Subject: [PATCH] feat: push predicates into virtual datasets (#31486) --- superset/config.py | 2 + superset/connectors/sqla/models.py | 6 +- superset/db_engine_specs/base.py | 2 +- superset/models/core.py | 8 ++ superset/models/helpers.py | 7 +- superset/sql/parse.py | 40 +++++++++ tests/unit_tests/db_engine_specs/test_base.py | 2 +- .../db_engine_specs/test_bigquery.py | 2 +- tests/unit_tests/models/core_test.py | 84 +++++++++++++++++++ tests/unit_tests/sql/parse_tests.py | 43 ++++++++++ 10 files changed, 191 insertions(+), 5 deletions(-) diff --git a/superset/config.py b/superset/config.py index 5d03016298aae..f36ffade3cfed 100644 --- a/superset/config.py +++ b/superset/config.py @@ -517,6 +517,8 @@ class D3TimeFormat(TypedDict, total=False): # Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the # query, and might break queries and/or allow users to bypass RLS. Use with care! "RLS_IN_SQLLAB": False, + # Try to optimize SQL queries — for now only predicate pushdown is supported. 
+ "OPTIMIZE_SQL": False, # When impersonating a user, use the email prefix instead of the username "IMPERSONATE_WITH_EMAIL_PREFIX": False, # Enable caching per impersonation key (e.g username) in a datasource where user diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index ac3a75481c02f..bb27678717196 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1568,7 +1568,11 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals # probe adhoc column type tbl, _ = self.get_from_clause(template_processor) qry = sa.select([sqla_column]).limit(1).select_from(tbl) - sql = self.database.compile_sqla_query(qry) + sql = self.database.compile_sqla_query( + qry, + catalog=self.catalog, + schema=self.schema, + ) col_desc = get_columns_description( self.database, self.catalog, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index cd2c318ad86a8..f239ef2019266 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1701,7 +1701,7 @@ def select_star( # pylint: disable=too-many-arguments ) if partition_query is not None: qry = partition_query - sql = database.compile_sqla_query(qry) + sql = database.compile_sqla_query(qry, table.catalog, table.schema) if indent: sql = SQLScript(sql, engine=cls.engine).format() return sql diff --git a/superset/models/core.py b/superset/models/core.py index afabea8f9065a..f71e5c5b6334e 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -74,6 +74,7 @@ ) from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin from superset.result_set import SupersetResultSet +from superset.sql.parse import SQLScript from superset.sql_parse import Table from superset.superset_typing import ( DbapiDescription, @@ -740,6 +741,7 @@ def compile_sqla_query( qry: Select, catalog: str | None = None, schema: str | None = None, + is_virtual: bool = False, ) -> str: with 
self.get_sqla_engine(catalog=catalog, schema=schema) as engine: sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) @@ -748,6 +750,12 @@ def compile_sqla_query( if engine.dialect.identifier_preparer._double_percents: # noqa sql = sql.replace("%%", "%") + # for now we only optimize queries on virtual datasources, since the only + # optimization available is predicate pushdown + if is_feature_enabled("OPTIMIZE_SQL") and is_virtual: + script = SQLScript(sql, self.db_engine_spec.engine).optimize() + sql = script.format() + return sql def select_star( # pylint: disable=too-many-arguments diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 9587ec2385efb..fb6dca6f65012 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -883,7 +883,12 @@ def get_query_str_extended( mutate: bool = True, ) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) - sql = self.database.compile_sqla_query(sqlaq.sqla_query) + sql = self.database.compile_sqla_query( + sqlaq.sqla_query, + catalog=self.catalog, + schema=self.schema, + is_virtual=bool(self.sql), + ) sql = self._apply_cte(sql, sqlaq.cte) if mutate: diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 91d7b51184f21..34ec9299d3600 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -17,6 +17,7 @@ from __future__ import annotations +import copy import enum import logging import re @@ -31,6 +32,7 @@ from sqlglot import exp from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.errors import ParseError +from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope from superset.exceptions import SupersetParseError @@ -227,6 +229,12 @@ def is_mutating(self) -> bool: """ raise NotImplementedError() + def optimize(self) -> BaseSQLStatement[InternalRepresentation]: + """ + Return optimized statement. 
+ """ + raise NotImplementedError() + def __str__(self) -> str: return self.format() @@ -431,6 +439,19 @@ def get_settings(self) -> dict[str, str | bool]: for eq in set_item.find_all(exp.EQ) } + def optimize(self) -> SQLStatement: + """ + Return optimized statement. + """ + # only optimize statements that have a custom dialect + if not self._dialect: + return SQLStatement(self._sql, self.engine, self._parsed.copy()) + + optimized = pushdown_predicates(self._parsed, dialect=self._dialect) + sql = optimized.sql(dialect=self._dialect) + + return SQLStatement(sql, self.engine, optimized) + class KQLSplitState(enum.Enum): """ @@ -589,6 +610,14 @@ def is_mutating(self) -> bool: """ return self._parsed.startswith(".") and not self._parsed.startswith(".show") + def optimize(self) -> KustoKQLStatement: + """ + Return optimized statement. + + Kusto KQL doesn't support optimization, so this method is a no-op. + """ + return KustoKQLStatement(self._sql, self.engine, self._parsed) + class SQLScript: """ @@ -643,6 +672,17 @@ def has_mutation(self) -> bool: """ return any(statement.is_mutating() for statement in self.statements) + def optimize(self) -> SQLScript: + """ + Return optimized script. 
+ """ + script = copy.deepcopy(self) + script.statements = [ # type: ignore + statement.optimize() for statement in self.statements + ] + + return script + def extract_tables_from_statement( statement: exp.Expression, diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 2644cd6e6f055..bbc3bb0edcefd 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -226,7 +226,7 @@ class NoLimitDBEngineSpec(BaseEngineSpec): # mock the database so we can compile the query database = mocker.MagicMock() - database.compile_sqla_query = lambda query: str( + database.compile_sqla_query = lambda query, catalog, schema: str( query.compile(dialect=sqlite.dialect()) ) diff --git a/tests/unit_tests/db_engine_specs/test_bigquery.py b/tests/unit_tests/db_engine_specs/test_bigquery.py index 458a6a0393e7d..c28ff0e46b49a 100644 --- a/tests/unit_tests/db_engine_specs/test_bigquery.py +++ b/tests/unit_tests/db_engine_specs/test_bigquery.py @@ -149,7 +149,7 @@ def test_select_star(mocker: MockerFixture) -> None: # mock the database so we can compile the query database = mocker.MagicMock() - database.compile_sqla_query = lambda query: str( + database.compile_sqla_query = lambda query, catalog, schema: str( query.compile(dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True}) ) diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 5bc3c86af657a..36ce618f887a5 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -21,9 +21,17 @@ import pytest from pytest_mock import MockerFixture +from sqlalchemy import ( + Column, + Integer, + MetaData, + select, + Table as SqlalchemyTable, +) from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import make_url from sqlalchemy.orm.session import Session +from sqlalchemy.sql import Select from superset.connectors.sqla.models import 
SqlaTable, TableColumn from superset.errors import SupersetErrorType @@ -45,6 +53,29 @@ } +@pytest.fixture +def query() -> Select: + """ + A nested query fixture used to test query optimization. + """ + metadata = MetaData() + some_table = SqlalchemyTable( + "some_table", + metadata, + Column("a", Integer), + Column("b", Integer), + Column("c", Integer), + ) + + inner_select = select(some_table.c.a, some_table.c.b, some_table.c.c) + outer_select = select(inner_select.c.a, inner_select.c.b).where( + inner_select.c.a > 1, + inner_select.c.b == 2, + ) + + return outer_select + + def test_get_metrics(mocker: MockerFixture) -> None: """ Tests for ``get_metrics``. @@ -683,3 +714,56 @@ def test_purge_oauth2_tokens(session: Session) -> None: # make sure database was not deleted... just in case database = session.query(Database).filter_by(id=database1.id).one() assert database.name == "my_oauth2_db" + + +def test_compile_sqla_query_no_optimization(query: Select) -> None: + """ + Test the `compile_sqla_query` method. + """ + from superset.models.core import Database + + database = Database( + database_name="db", + sqlalchemy_uri="sqlite://", + ) + + space = " " + + assert ( + database.compile_sqla_query(query, is_virtual=True) + == f"""SELECT anon_1.a, anon_1.b{space} +FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c{space} +FROM some_table) AS anon_1{space} +WHERE anon_1.a > 1 AND anon_1.b = 2""" + ) + + +@with_feature_flags(OPTIMIZE_SQL=True) +def test_compile_sqla_query(query: Select) -> None: + """ + Test the `compile_sqla_query` method. 
+ """ + from superset.models.core import Database + + database = Database( + database_name="db", + sqlalchemy_uri="sqlite://", + ) + + assert ( + database.compile_sqla_query(query, is_virtual=True) + == """SELECT + anon_1.a, + anon_1.b +FROM ( + SELECT + some_table.a AS a, + some_table.b AS b, + some_table.c AS c + FROM some_table + WHERE + some_table.a > 1 AND some_table.b = 2 +) AS anon_1 +WHERE + TRUE AND TRUE""" + ) diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 5103ef12eccf2..1eabb78e05d95 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -1070,3 +1070,46 @@ def test_is_mutating(engine: str) -> None: "with source as ( select 1 as one ) select * from source", engine=engine, ).is_mutating() + + +def test_optimize() -> None: + """ + Test that the `optimize` method works as expected. + + The SQL optimization only works with engines that have a corresponding dialect. + """ + sql = """ +SELECT anon_1.a, anon_1.b +FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c +FROM some_table) AS anon_1 +WHERE anon_1.a > 1 AND anon_1.b = 2 + """ + + optimized = """SELECT + anon_1.a, + anon_1.b +FROM ( + SELECT + some_table.a AS a, + some_table.b AS b, + some_table.c AS c + FROM some_table + WHERE + some_table.a > 1 AND some_table.b = 2 +) AS anon_1 +WHERE + TRUE AND TRUE""" + + not_optimized = """ +SELECT anon_1.a, + anon_1.b +FROM + (SELECT some_table.a AS a, + some_table.b AS b, + some_table.c AS c + FROM some_table) AS anon_1 +WHERE anon_1.a > 1 + AND anon_1.b = 2""" + + assert SQLStatement(sql, "sqlite").optimize().format() == optimized + assert SQLStatement(sql, "firebolt").optimize().format() == not_optimized