diff --git a/ibis-server/app/main.py b/ibis-server/app/main.py
index 56c42a7b0..dc8c38d49 100644
--- a/ibis-server/app/main.py
+++ b/ibis-server/app/main.py
@@ -9,7 +9,7 @@
 from starlette.responses import PlainTextResponse
 
 from app.config import get_config
-from app.mdl.http import get_http_client, warmup_http_client
+from app.mdl.java_engine import get_java_engine_connector
 from app.middleware import ProcessTimeMiddleware, RequestLogMiddleware
 from app.model import ConfigModel, CustomHttpError
 from app.routers import v2, v3
@@ -19,11 +19,13 @@
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    asyncio.create_task(warmup_http_client())  # noqa: RUF006
+    java_engine_connector = get_java_engine_connector()
+    java_engine_connector.start()
+    asyncio.create_task(java_engine_connector.warmup())  # noqa: RUF006
 
     yield
 
-    await get_http_client().aclose()
+    await java_engine_connector.close()
 
 
 app = FastAPI(lifespan=lifespan)
diff --git a/ibis-server/app/mdl/http.py b/ibis-server/app/mdl/http.py
deleted file mode 100644
index d68ffdb70..000000000
--- a/ibis-server/app/mdl/http.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import anyio
-import httpcore
-import httpx
-
-from app.config import get_config
-
-wren_engine_endpoint = get_config().wren_engine_endpoint
-
-
-client = httpx.AsyncClient(
-    base_url=wren_engine_endpoint,
-    headers={
-        "Content-Type": "application/json",
-        "Accept": "application/json",
-    },
-)
-
-
-async def warmup_http_client(timeout=30):
-    for _ in range(timeout):
-        try:
-            response = await client.get("/v2/health")
-            if response.status_code == 200:
-                return
-        except (
-            httpx.ConnectError,
-            httpx.HTTPStatusError,
-            httpx.TimeoutException,
-            httpcore.ReadTimeout,
-        ):
-            await anyio.sleep(1)
-
-
-def get_http_client() -> httpx.AsyncClient:
-    return client
diff --git a/ibis-server/app/mdl/java_engine.py b/ibis-server/app/mdl/java_engine.py
new file mode 100644
index 000000000..851f010d5
--- /dev/null
+++ b/ibis-server/app/mdl/java_engine.py
@@ -0,0 +1,60 @@
+import anyio
+import httpcore
+import httpx
+from orjson import orjson
+
+from app.config import get_config
+
+wren_engine_endpoint = get_config().wren_engine_endpoint
+
+
+class Singleton(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super().__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+
+class JavaEngineConnector(metaclass=Singleton):
+    def __init__(self):
+        self.client = None
+
+    def start(self):
+        self.client = httpx.AsyncClient(
+            base_url=wren_engine_endpoint,
+            headers={
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+            },
+        )
+
+    async def dry_plan(self, manifest_str: str, sql: str):
+        r = await self.client.request(
+            method="GET",
+            url="/v2/mdl/dry-plan",
+            content=orjson.dumps({"manifestStr": manifest_str, "sql": sql}),
+        )
+        return r.raise_for_status().text.replace("\n", " ")
+
+    async def warmup(self, timeout=30):
+        for _ in range(timeout):
+            try:
+                response = await self.client.get("/v2/health")
+                if response.status_code == 200:
+                    return
+            except (
+                httpx.ConnectError,
+                httpx.HTTPStatusError,
+                httpx.TimeoutException,
+                httpcore.ReadTimeout,
+            ):
+                await anyio.sleep(1)
+
+    async def close(self):
+        await self.client.aclose()
+
+
+def get_java_engine_connector() -> JavaEngineConnector:
+    return JavaEngineConnector()
diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py
index 1c3db01a7..cbaf1acaa 100644
--- a/ibis-server/app/mdl/rewriter.py
+++ b/ibis-server/app/mdl/rewriter.py
@@ -1,7 +1,6 @@
 import importlib
 
 import httpx
-import orjson
 import sqlglot
 from anyio import to_thread
 from loguru import logger
@@ -12,26 +11,23 @@
     get_session_context,
     to_json_base64,
 )
-from app.mdl.http import get_http_client
+from app.mdl.java_engine import JavaEngineConnector
 from app.model import InternalServerError, UnprocessableEntityError
 from app.model.data_source import DataSource
 
-wren_engine_endpoint = get_config().wren_engine_endpoint
-
 # To register custom dialects from ibis library for sqlglot
 importlib.import_module("ibis.backends.sql.dialects")
 
 # Register custom dialects
 importlib.import_module("app.custom_sqlglot.dialects")
 
-client = get_http_client()
-
 
 class Rewriter:
     def __init__(
         self,
         manifest_str: str,
         data_source: DataSource = None,
+        java_engine_connector=None,
         experiment=False,
     ):
         self.manifest_str = manifest_str
@@ -42,7 +38,7 @@ def __init__(
             function_path = config.get_remote_function_list_path(data_source)
             self._rewriter = EmbeddedEngineRewriter(manifest_str, function_path)
         else:
-            self._rewriter = ExternalEngineRewriter(manifest_str)
+            self._rewriter = ExternalEngineRewriter(manifest_str, java_engine_connector)
 
     async def rewrite(self, sql: str) -> str:
         planned_sql = await self._rewriter.rewrite(sql)
@@ -67,8 +63,9 @@ def _get_write_dialect(cls, data_source: DataSource) -> str:
 
 
 class ExternalEngineRewriter:
-    def __init__(self, manifest_str: str):
+    def __init__(self, manifest_str: str, java_engine_connector: JavaEngineConnector):
         self.manifest_str = manifest_str
+        self.java_engine_connector = java_engine_connector
 
     async def rewrite(self, sql: str) -> str:
         try:
@@ -76,12 +73,7 @@ async def rewrite(self, sql: str) -> str:
             tables = extractor.resolve_used_table_names(sql)
             manifest = extractor.extract_by(tables)
             manifest_str = to_json_base64(manifest)
-            r = await client.request(
-                method="GET",
-                url=f"{wren_engine_endpoint}/v2/mdl/dry-plan",
-                content=orjson.dumps({"manifestStr": manifest_str, "sql": sql}),
-            )
-            return r.raise_for_status().text.replace("\n", " ")
+            return await self.java_engine_connector.dry_plan(manifest_str, sql)
         except httpx.ConnectError as e:
             raise WrenEngineError(f"Can not connect to Java Engine: {e}")
         except httpx.TimeoutException as e: