Skip to content

Commit

Permalink
feat: push predicates into virtual datasets (#31486)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Jan 9, 2025
1 parent f29eafd commit e4b3ecd
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 5 deletions.
2 changes: 2 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,8 @@ class D3TimeFormat(TypedDict, total=False):
# Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the
# query, and might break queries and/or allow users to bypass RLS. Use with care!
"RLS_IN_SQLLAB": False,
# Try to optimize SQL queries — for now only predicate pushdown is supported.
"OPTIMIZE_SQL": False,
# When impersonating a user, use the email prefix instead of the username
"IMPERSONATE_WITH_EMAIL_PREFIX": False,
# Enable caching per impersonation key (e.g username) in a datasource where user
Expand Down
6 changes: 5 additions & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,7 +1568,11 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
# probe adhoc column type
tbl, _ = self.get_from_clause(template_processor)
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry)
sql = self.database.compile_sqla_query(
qry,
catalog=self.catalog,
schema=self.schema,
)
col_desc = get_columns_description(
self.database,
self.catalog,
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,7 +1701,7 @@ def select_star( # pylint: disable=too-many-arguments
)
if partition_query is not None:
qry = partition_query
sql = database.compile_sqla_query(qry)
sql = database.compile_sqla_query(qry, table.catalog, table.schema)
if indent:
sql = SQLScript(sql, engine=cls.engine).format()
return sql
Expand Down
8 changes: 8 additions & 0 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
)
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin
from superset.result_set import SupersetResultSet
from superset.sql.parse import SQLScript
from superset.sql_parse import Table
from superset.superset_typing import (
DbapiDescription,
Expand Down Expand Up @@ -740,6 +741,7 @@ def compile_sqla_query(
qry: Select,
catalog: str | None = None,
schema: str | None = None,
is_virtual: bool = False,
) -> str:
with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
Expand All @@ -748,6 +750,12 @@ def compile_sqla_query(
if engine.dialect.identifier_preparer._double_percents: # noqa
sql = sql.replace("%%", "%")

# for now we only optimize queries on virtual datasources, since the only
# optimization available is predicate pushdown
if is_feature_enabled("OPTIMIZE_SQL") and is_virtual:
script = SQLScript(sql, self.db_engine_spec.engine).optimize()
sql = script.format()

return sql

def select_star( # pylint: disable=too-many-arguments
Expand Down
7 changes: 6 additions & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,12 @@ def get_query_str_extended(
mutate: bool = True,
) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self.database.compile_sqla_query(
sqlaq.sqla_query,
catalog=self.catalog,
schema=self.schema,
is_virtual=bool(self.sql),
)
sql = self._apply_cte(sql, sqlaq.cte)

if mutate:
Expand Down
40 changes: 40 additions & 0 deletions superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import copy
import enum
import logging
import re
Expand All @@ -31,6 +32,7 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope

from superset.exceptions import SupersetParseError
Expand Down Expand Up @@ -227,6 +229,12 @@ def is_mutating(self) -> bool:
"""
raise NotImplementedError()

def optimize(self) -> BaseSQLStatement[InternalRepresentation]:
    """
    Produce an optimized version of this statement.

    The base class provides no implementation and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()

def __str__(self) -> str:
return self.format()

Expand Down Expand Up @@ -431,6 +439,19 @@ def get_settings(self) -> dict[str, str | bool]:
for eq in set_item.find_all(exp.EQ)
}

def optimize(self) -> SQLStatement:
    """
    Return an optimized copy of this statement.

    The only optimization applied is predicate pushdown, and only when the
    statement has a dialect. Without one, an unmodified copy of the
    statement is returned.

    The current statement is never modified: the optimizer is always run on
    a copy of the parsed tree.
    """
    # only optimize statements that have a custom dialect
    if not self._dialect:
        return SQLStatement(self._sql, self.engine, self._parsed.copy())

    # `pushdown_predicates` rewrites the expression it receives in place and
    # returns that same object; hand it a copy so `self._parsed` stays in
    # sync with `self._sql` and this statement remains reusable.
    optimized = pushdown_predicates(self._parsed.copy(), dialect=self._dialect)
    sql = optimized.sql(dialect=self._dialect)

    return SQLStatement(sql, self.engine, optimized)


class KQLSplitState(enum.Enum):
"""
Expand Down Expand Up @@ -589,6 +610,14 @@ def is_mutating(self) -> bool:
"""
return self._parsed.startswith(".") and not self._parsed.startswith(".show")

def optimize(self) -> KustoKQLStatement:
    """
    Return a new statement equivalent to this one.

    No optimization is performed for Kusto KQL; the new statement is built
    from the same SQL, engine and parsed representation as the current one.
    """
    sql, engine, parsed = self._sql, self.engine, self._parsed
    return KustoKQLStatement(sql, engine, parsed)


class SQLScript:
"""
Expand Down Expand Up @@ -643,6 +672,17 @@ def has_mutation(self) -> bool:
"""
return any(statement.is_mutating() for statement in self.statements)

def optimize(self) -> SQLScript:
    """
    Build a new script whose statements are the optimized versions of the
    statements in this script.

    The script is deep-copied and the copy's statements are then replaced by
    the result of calling ``optimize()`` on each statement of this script.
    """
    clone = copy.deepcopy(self)
    optimized_statements = [stmt.optimize() for stmt in self.statements]
    clone.statements = optimized_statements  # type: ignore
    return clone


def extract_tables_from_statement(
statement: exp.Expression,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/db_engine_specs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class NoLimitDBEngineSpec(BaseEngineSpec):

# mock the database so we can compile the query
database = mocker.MagicMock()
database.compile_sqla_query = lambda query: str(
database.compile_sqla_query = lambda query, catalog, schema: str(
query.compile(dialect=sqlite.dialect())
)

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/db_engine_specs/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_select_star(mocker: MockerFixture) -> None:

# mock the database so we can compile the query
database = mocker.MagicMock()
database.compile_sqla_query = lambda query: str(
database.compile_sqla_query = lambda query, catalog, schema: str(
query.compile(dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True})
)

Expand Down
84 changes: 84 additions & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,17 @@

import pytest
from pytest_mock import MockerFixture
from sqlalchemy import (
Column,
Integer,
MetaData,
select,
Table as SqlalchemyTable,
)
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import Select

from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType
Expand All @@ -45,6 +53,29 @@
}


@pytest.fixture
def query() -> Select:
    """
    Build the nested ``SELECT`` used by the query optimization tests.

    The inner query reads columns ``a``, ``b`` and ``c`` from ``some_table``;
    the outer query selects ``a`` and ``b`` from it and filters on both.
    """
    table = SqlalchemyTable(
        "some_table",
        MetaData(),
        Column("a", Integer),
        Column("b", Integer),
        Column("c", Integer),
    )
    inner = select(table.c.a, table.c.b, table.c.c)

    # both predicates sit on the outer query
    predicates = (inner.c.a > 1, inner.c.b == 2)
    outer = select(inner.c.a, inner.c.b).where(*predicates)

    return outer


def test_get_metrics(mocker: MockerFixture) -> None:
"""
Tests for ``get_metrics``.
Expand Down Expand Up @@ -683,3 +714,56 @@ def test_purge_oauth2_tokens(session: Session) -> None:
# make sure database was not deleted... just in case
database = session.query(Database).filter_by(id=database1.id).one()
assert database.name == "my_oauth2_db"


def test_compile_sqla_query_no_optimization(query: Select) -> None:
"""
Test the `compile_sqla_query` method without the `OPTIMIZE_SQL` feature flag.

Even though the query is compiled with `is_virtual=True`, the expected SQL
still has both predicates in the outer `WHERE` clause.
"""
from superset.models.core import Database

database = Database(
database_name="db",
sqlalchemy_uri="sqlite://",
)

# the expected SQL has a trailing space on some lines; interpolating it
# presumably keeps it from being stripped out of this file
space = " "

assert (
database.compile_sqla_query(query, is_virtual=True)
== f"""SELECT anon_1.a, anon_1.b{space}
FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c{space}
FROM some_table) AS anon_1{space}
WHERE anon_1.a > 1 AND anon_1.b = 2"""
)


@with_feature_flags(OPTIMIZE_SQL=True)
def test_compile_sqla_query(query: Select) -> None:
"""
Test the `compile_sqla_query` method with the `OPTIMIZE_SQL` feature flag on.

With `is_virtual=True` the expected SQL has the two predicates inside the
subquery, and the outer `WHERE` clause reduced to `TRUE AND TRUE`.
"""
from superset.models.core import Database

database = Database(
database_name="db",
sqlalchemy_uri="sqlite://",
)

assert (
database.compile_sqla_query(query, is_virtual=True)
== """SELECT
anon_1.a,
anon_1.b
FROM (
SELECT
some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table
WHERE
some_table.a > 1 AND some_table.b = 2
) AS anon_1
WHERE
TRUE AND TRUE"""
)
43 changes: 43 additions & 0 deletions tests/unit_tests/sql/parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,3 +1070,46 @@ def test_is_mutating(engine: str) -> None:
"with source as ( select 1 as one ) select * from source",
engine=engine,
).is_mutating()


def test_optimize() -> None:
"""
Test that the `optimize` method works as expected.

The SQL optimization only works with engines that have a corresponding dialect.
The same statement is optimized under two engines and the formatted results
are compared with an optimized and a non-optimized expectation.
"""
sql = """
SELECT anon_1.a, anon_1.b
FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c
FROM some_table) AS anon_1
WHERE anon_1.a > 1 AND anon_1.b = 2
"""

optimized = """SELECT
anon_1.a,
anon_1.b
FROM (
SELECT
some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table
WHERE
some_table.a > 1 AND some_table.b = 2
) AS anon_1
WHERE
TRUE AND TRUE"""

not_optimized = """
SELECT anon_1.a,
anon_1.b
FROM
(SELECT some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table) AS anon_1
WHERE anon_1.a > 1
AND anon_1.b = 2"""

# with the sqlite engine the predicates are expected inside the subquery
assert SQLStatement(sql, "sqlite").optimize().format() == optimized
# with the firebolt engine (presumably no dialect available) the predicates
# are expected to stay in the outer query
assert SQLStatement(sql, "firebolt").optimize().format() == not_optimized

0 comments on commit e4b3ecd

Please sign in to comment.