Skip to content

Commit

Permalink
chore: wrap http client to init in lifespan
Browse files Browse the repository at this point in the history
  • Loading branch information
grieve54706 committed Dec 16, 2024
1 parent f3b88a0 commit 22cc17c
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 52 deletions.
8 changes: 5 additions & 3 deletions ibis-server/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from starlette.responses import PlainTextResponse

from app.config import get_config
from app.mdl.http import get_http_client, warmup_http_client
from app.mdl.java_engine import get_java_engine_connector
from app.middleware import ProcessTimeMiddleware, RequestLogMiddleware
from app.model import ConfigModel, CustomHttpError
from app.routers import v2, v3
Expand All @@ -19,11 +19,13 @@

@asynccontextmanager
async def lifespan(app: FastAPI):
asyncio.create_task(warmup_http_client()) # noqa: RUF006
java_engine_connector = get_java_engine_connector()
java_engine_connector.start()
asyncio.create_task(java_engine_connector.warmup()) # noqa: RUF006

yield

await get_http_client().aclose()
await java_engine_connector.close()


app = FastAPI(lifespan=lifespan)
Expand Down
35 changes: 0 additions & 35 deletions ibis-server/app/mdl/http.py

This file was deleted.

60 changes: 60 additions & 0 deletions ibis-server/app/mdl/java_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import anyio
import httpcore
import httpx
from orjson import orjson

from app.config import get_config

wren_engine_endpoint = get_config().wren_engine_endpoint


class Singleton(type):
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]


class JavaEngineConnector(metaclass=Singleton):
def __init__(self):
self.client = None

def start(self):
self.client = httpx.AsyncClient(
base_url=wren_engine_endpoint,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
},
)

async def dry_plan(self, manifest_str: str, sql: str):
r = await self.client.request(
method="GET",
url="/v2/mdl/dry-plan",
content=orjson.dumps({"manifestStr": manifest_str, "sql": sql}),
)
return r.raise_for_status().text.replace("\n", " ")

async def warmup(self, timeout=30):
for _ in range(timeout):
try:
response = await self.client.get("/v2/health")
if response.status_code == 200:
return
except (
httpx.ConnectError,
httpx.HTTPStatusError,
httpx.TimeoutException,
httpcore.ReadTimeout,
):
await anyio.sleep(1)

async def close(self):
await self.client.aclose()


def get_java_engine_connector() -> JavaEngineConnector:
return JavaEngineConnector()
20 changes: 6 additions & 14 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import importlib

import httpx
import orjson
import sqlglot
from anyio import to_thread
from loguru import logger
Expand All @@ -12,26 +11,23 @@
get_session_context,
to_json_base64,
)
from app.mdl.http import get_http_client
from app.mdl.java_engine import JavaEngineConnector
from app.model import InternalServerError, UnprocessableEntityError
from app.model.data_source import DataSource

wren_engine_endpoint = get_config().wren_engine_endpoint

# To register custom dialects from ibis library for sqlglot
importlib.import_module("ibis.backends.sql.dialects")

# Register custom dialects
importlib.import_module("app.custom_sqlglot.dialects")

client = get_http_client()


class Rewriter:
def __init__(
self,
manifest_str: str,
data_source: DataSource = None,
java_engine_connector=None,
experiment=False,
):
self.manifest_str = manifest_str
Expand All @@ -42,7 +38,7 @@ def __init__(
function_path = config.get_remote_function_list_path(data_source)
self._rewriter = EmbeddedEngineRewriter(manifest_str, function_path)
else:
self._rewriter = ExternalEngineRewriter(manifest_str)
self._rewriter = ExternalEngineRewriter(manifest_str, java_engine_connector)

async def rewrite(self, sql: str) -> str:
planned_sql = await self._rewriter.rewrite(sql)
Expand All @@ -67,21 +63,17 @@ def _get_write_dialect(cls, data_source: DataSource) -> str:


class ExternalEngineRewriter:
def __init__(self, manifest_str: str):
def __init__(self, manifest_str: str, java_engine_connector: JavaEngineConnector):
self.manifest_str = manifest_str
self.java_engine_connector = java_engine_connector

async def rewrite(self, sql: str) -> str:
try:
extractor = get_manifest_extractor(self.manifest_str)
tables = extractor.resolve_used_table_names(sql)
manifest = extractor.extract_by(tables)
manifest_str = to_json_base64(manifest)
r = await client.request(
method="GET",
url=f"{wren_engine_endpoint}/v2/mdl/dry-plan",
content=orjson.dumps({"manifestStr": manifest_str, "sql": sql}),
)
return r.raise_for_status().text.replace("\n", " ")
return await self.java_engine_connector.dry_plan(manifest_str, sql)
except httpx.ConnectError as e:
raise WrenEngineError(f"Can not connect to Java Engine: {e}")
except httpx.TimeoutException as e:
Expand Down

0 comments on commit 22cc17c

Please sign in to comment.