Skip to content

Commit

Permalink
chore(rewriter): pull out duplicate code (#1001)
Browse files Browse the repository at this point in the history
  • Loading branch information
grieve54706 authored Dec 23, 2024
1 parent d32124e commit 662d75b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 27 deletions.
56 changes: 29 additions & 27 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,45 @@ def __init__(
self,
manifest_str: str,
data_source: DataSource = None,
java_engine_connector=None,
java_engine_connector: JavaEngineConnector = None,
experiment=False,
):
self.manifest_str = manifest_str
self.data_source = data_source
self.experiment = experiment
if experiment:
config = get_config()
function_path = config.get_remote_function_list_path(data_source)
self._rewriter = EmbeddedEngineRewriter(manifest_str, function_path)
function_path = get_config().get_remote_function_list_path(data_source)
self._rewriter = EmbeddedEngineRewriter(function_path)
else:
self._rewriter = ExternalEngineRewriter(manifest_str, java_engine_connector)
self._rewriter = ExternalEngineRewriter(java_engine_connector)

def _transpile(self, planned_sql: str) -> str:
read = self._get_read_dialect(self.experiment)
write = self._get_write_dialect(self.data_source)
return sqlglot.transpile(planned_sql, read=read, write=write)[0]

async def rewrite(self, sql: str) -> str:
planned_sql = await self._rewriter.rewrite(sql)
manifest_str = self._extract_manifest(self.manifest_str, sql)
logger.debug("Extracted manifest: {}", manifest_str)
planned_sql = await self._rewriter.rewrite(manifest_str, sql)
logger.debug("Planned SQL: {}", planned_sql)
dialect_sql = self._transpile(planned_sql) if self.data_source else planned_sql
logger.debug("Dialect SQL: {}", dialect_sql)
return dialect_sql

def _transpile(self, planned_sql: str) -> str:
write = self._get_write_dialect(self.data_source)
if self.experiment:
read = None
else:
read = "trino"
return sqlglot.transpile(planned_sql, read=read, write=write)[0]
@staticmethod
def _extract_manifest(manifest_str: str, sql: str) -> str:
try:
extractor = get_manifest_extractor(manifest_str)
tables = extractor.resolve_used_table_names(sql)
manifest = extractor.extract_by(tables)
return to_json_base64(manifest)
except Exception as e:
raise RewriteError(str(e))

@classmethod
def _get_read_dialect(cls, experiment) -> str | None:
return None if experiment else "trino"

@classmethod
def _get_write_dialect(cls, data_source: DataSource) -> str:
Expand All @@ -63,16 +75,11 @@ def _get_write_dialect(cls, data_source: DataSource) -> str:


class ExternalEngineRewriter:
def __init__(self, manifest_str: str, java_engine_connector: JavaEngineConnector):
self.manifest_str = manifest_str
def __init__(self, java_engine_connector: JavaEngineConnector):
self.java_engine_connector = java_engine_connector

async def rewrite(self, sql: str) -> str:
async def rewrite(self, manifest_str: str, sql: str) -> str:
try:
extractor = get_manifest_extractor(self.manifest_str)
tables = extractor.resolve_used_table_names(sql)
manifest = extractor.extract_by(tables)
manifest_str = to_json_base64(manifest)
return await self.java_engine_connector.dry_plan(manifest_str, sql)
except httpx.ConnectError as e:
raise WrenEngineError(f"Can not connect to Java Engine: {e}")
Expand All @@ -83,16 +90,11 @@ async def rewrite(self, sql: str) -> str:


class EmbeddedEngineRewriter:
def __init__(self, manifest_str: str, function_path: str):
self.manifest_str = manifest_str
def __init__(self, function_path: str):
self.function_path = function_path

async def rewrite(self, sql: str) -> str:
async def rewrite(self, manifest_str: str, sql: str) -> str:
try:
extractor = get_manifest_extractor(self.manifest_str)
tables = extractor.resolve_used_table_names(sql)
manifest = extractor.extract_by(tables)
manifest_str = to_json_base64(manifest)
session_context = get_session_context(manifest_str, self.function_path)
return await to_thread.run_sync(session_context.transform_sql, sql)
except Exception as e:
Expand Down
16 changes: 16 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,22 @@ async def test_query_with_limit(client, manifest_str, postgres: PostgresContaine
assert len(result["data"]) == 1


async def test_query_with_invalid_manifest_str(
client, manifest_str, postgres: PostgresContainer
):
connection_info = _to_connection_info(postgres)
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": "xxx",
"sql": 'SELECT * FROM "Orders" LIMIT 1',
},
)
assert response.status_code == 422
assert response.text == "Base64 decode error: Invalid padding"


async def test_query_without_manifest(client, postgres: PostgresContainer):
connection_info = _to_connection_info(postgres)
response = await client.post(
Expand Down

0 comments on commit 662d75b

Please sign in to comment.