From 5e4f0b94a05a39c3cfdc808a9b2d18e7da624c3f Mon Sep 17 00:00:00 2001 From: Sigurd Pettersen Date: Fri, 22 Nov 2024 11:11:50 +0100 Subject: [PATCH] Implement async lifespan for FastAPI and refactor HTTP client usage in surface query service --- backend_py/primary/primary/main.py | 23 +++++++++++++++++++ .../primary/primary/routers/surface/router.py | 6 ++++- .../surface_query_service.py | 18 ++++++++++----- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/backend_py/primary/primary/main.py b/backend_py/primary/primary/main.py index 6f12bf57a..477d3af50 100644 --- a/backend_py/primary/primary/main.py +++ b/backend_py/primary/primary/main.py @@ -57,7 +57,30 @@ def custom_generate_unique_id(route: APIRoute) -> str: return f"{route.name}" + + + +import httpx +from contextlib import asynccontextmanager + +@asynccontextmanager +async def app_lifespan(app: FastAPI): + print("start lifespan") + #limits = httpx.Limits(max_keepalive_connections=200, max_connections=300) + + #app.state.requests_client = httpx.AsyncClient(http2=True, verify=False, limits=limits) + #app.state.requests_client = httpx.AsyncClient(verify=False) + app.state.requests_client = httpx.AsyncClient() + yield + + await app.state.requests_client.aclose() + print("end lifespan") + + + + app = FastAPI( + lifespan=app_lifespan, generate_unique_id_function=custom_generate_unique_id, root_path="/api", default_response_class=ORJSONResponse, diff --git a/backend_py/primary/primary/routers/surface/router.py b/backend_py/primary/primary/routers/surface/router.py index 02d74bab0..755eec1f6 100644 --- a/backend_py/primary/primary/routers/surface/router.py +++ b/backend_py/primary/primary/routers/surface/router.py @@ -2,7 +2,7 @@ import logging from typing import Annotated, List, Optional, Literal -from fastapi import APIRouter, Depends, HTTPException, Query, Response, Body, status +from fastapi import APIRouter, Depends, HTTPException, Query, Response, Request, Body, status from webviz_pkg.core_utils.perf_metrics import PerfMetrics from primary.services.sumo_access.case_inspector import CaseInspector @@ -231,6 +231,7 @@ async def post_get_surface_intersection( @router.post("/sample_surface_in_points") async def post_sample_surface_in_points( + request: Request, case_uuid: str = Query(description="Sumo case uuid"), ensemble_name: str = Query(description="Ensemble name"), surface_name: str = Query(description="Surface name"), @@ -242,7 +243,10 @@ async def post_sample_surface_in_points( sumo_access_token = authenticated_user.get_sumo_access_token() + async_client = request.app.state.requests_client + result_arr: List[RealizationSampleResult] = await batch_sample_surface_in_points_async( + async_client=async_client, sumo_access_token=sumo_access_token, case_uuid=case_uuid, iteration_name=ensemble_name, diff --git a/backend_py/primary/primary/services/surface_query_service/surface_query_service.py b/backend_py/primary/primary/services/surface_query_service/surface_query_service.py index 6485f895e..9c8893c76 100644 --- a/backend_py/primary/primary/services/surface_query_service/surface_query_service.py +++ b/backend_py/primary/primary/services/surface_query_service/surface_query_service.py @@ -48,6 +48,7 @@ class _PointSamplingResponseBody(BaseModel): async def batch_sample_surface_in_points_async( + async_client: httpx.AsyncClient, sumo_access_token: str, case_uuid: str, iteration_name: str, @@ -117,14 +118,19 @@ async def batch_sample_surface_in_points_async( json_request_body = request_body.model_dump() - async with httpx.AsyncClient(timeout=300) as client: - LOGGER.info(f"Running async go point sampling for surface: {surface_name}") + LOGGER.info(f"Running async go point sampling for surface: {surface_name}") + perf_metrics.record_lap("prepare_call") - perf_metrics.record_lap("prepare_call") - response: httpx.Response = await client.post( - url=SERVICE_ENDPOINT, json=json_request_body - ) + response: httpx.Response = await async_client.post(url=SERVICE_ENDPOINT, json=json_request_body) + + + # async with httpx.AsyncClient(timeout=300) as client: + # LOGGER.info(f"Running async go point sampling for surface: {surface_name}") + + # perf_metrics.record_lap("prepare_call") + + # response: httpx.Response = await client.post(url=SERVICE_ENDPOINT, json=json_request_body) perf_metrics.record_lap("main-call")