Skip to content

Commit

Permalink
Experiment with in-mem caching
Browse files Browse the repository at this point in the history
  • Loading branch information
sigurdp committed Oct 16, 2023
1 parent 94f7481 commit 3e094ad
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 1 deletion.
156 changes: 156 additions & 0 deletions backend/src/backend/experiments/calc_surf_isec_inmem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import os
import signal
import numpy as np
import logging
from typing import List
import multiprocessing
import xtgeo
import io

# from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass

from src.services.sumo_access.surface_access import SurfaceAccess
from src.services.utils.authenticated_user import AuthenticatedUser
from src.backend.primary.routers.surface import schemas
from src.backend.utils.perf_metrics import PerfMetrics

LOGGER = logging.getLogger(__name__)


@dataclass
class SurfCacheEntry:
    """One cache slot for a realization surface.

    surf is None when the fetch for this realization failed, so the failure
    itself is cached and the surface is not re-requested on later calls
    (see calc_surf_isec_inmem, which caches SurfCacheEntry(None) on miss).
    """

    surf: xtgeo.RegularSurface | None

class InMemSurfCache:
    """Process-local, unbounded dict cache of SurfCacheEntry objects, one per realization surface."""

    def __init__(self):
        # Flat mapping from the colon-joined surface address to its cache entry
        self._dict = {}

    @staticmethod
    def _make_key(case_uuid: str, ensemble_name: str, name: str, attribute: str, real: int) -> str:
        # One string key per (case, ensemble, surface name, attribute, realization)
        return f"{case_uuid}:{ensemble_name}:{name}:{attribute}:{real}"

    def set(self, case_uuid: str, ensemble_name: str, name: str, attribute: str, real: int, cache_entry: SurfCacheEntry):
        """Store cache_entry for the given surface address, overwriting any previous entry."""
        self._dict[self._make_key(case_uuid, ensemble_name, name, attribute, real)] = cache_entry

    def get(self, case_uuid: str, ensemble_name: str, name: str, attribute: str, real: int) -> SurfCacheEntry | None:
        """Return the cached entry for the given surface address, or None on a cache miss."""
        return self._dict.get(self._make_key(case_uuid, ensemble_name, name, attribute, real))


# Module-level singleton cache; lives (and grows without bound) for the
# lifetime of this worker process.
IN_MEM_SURF_CACHE = InMemSurfCache()


@dataclass
class SurfItem:
    """Address of one realization surface to fetch from Sumo (work item for the pool workers)."""

    # access_token: str
    case_uuid: str
    ensemble_name: str
    name: str
    attribute: str
    real: int
    # fence_arr: np.ndarray
# fence_arr: np.ndarray


@dataclass
class ResultItem:
    """Per-surface result with timing info.

    NOTE(review): appears unused in this module — fetch_a_surf returns raw
    bytes rather than a ResultItem; likely a leftover from a sibling experiment.
    """

    perf_info: str
    line: np.ndarray


# Per-worker-process SurfaceAccess instance; populated by init_access() when
# the multiprocessing pool spawns each worker.
global_access = None


def init_access(access_token: str, case_uuid: str, ensemble_name: str):
    """Pool-worker initializer: restore default signal handling and build the
    worker's module-global SurfaceAccess instance.
    """
    # !!!!!!!!!!!!!
    # See: https://github.com/tiangolo/fastapi/issues/1487#issuecomment-1157066306
    # Reset signal handling inherited from the parent (uvicorn/FastAPI) so the
    # worker process responds normally to SIGTERM/SIGINT.
    signal.set_wakeup_fd(-1)
    signal.signal(signal.SIGTERM, signal.SIG_DFL)
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    global global_access
    global_access = SurfaceAccess.from_case_uuid_sync(access_token, case_uuid, ensemble_name)


def fetch_a_surf(item: SurfItem) -> bytes | None:
    """Download the raw surface blob for one realization; runs in a pool worker.

    Returns the downloaded bytes, or None when no matching surface exists.
    Relies on the module-global SurfaceAccess set up by init_access().
    """
    print(f">>>> fetch_a_surf {item.real=}", flush=True)
    perf_metrics = PerfMetrics()

    access = global_access
    # access = await SurfaceAccess.from_case_uuid(item.access_token, item.case_uuid, item.ensemble_name)
    perf_metrics.record_lap("access")

    surf_bytes = access.get_realization_surface_bytes_sync(real_num=item.real, name=item.name, attribute=item.attribute)
    if surf_bytes is None:
        # Surface not found in Sumo; caller caches this outcome as a None entry
        return None
    perf_metrics.record_lap("fetch")

    print(f">>>> fetch_a_surf {item.real=} done", flush=True)

    return surf_bytes


async def calc_surf_isec_inmem(
    perf_metrics: PerfMetrics,
    authenticated_user: AuthenticatedUser,
    case_uuid: str,
    ensemble_name: str,
    name: str,
    attribute: str,
    num_reals: int,
    cutting_plane: schemas.CuttingPlane,
) -> List[schemas.SurfaceIntersectionData]:
    """Intersect realization surfaces with a cutting plane, using an in-memory surface cache.

    Surfaces already present in IN_MEM_SURF_CACHE are reused directly; the
    remainder are downloaded in a multiprocessing pool, parsed with xtgeo, and
    inserted into the cache. Failed fetches are cached as None entries so they
    are not retried on subsequent calls.
    """
    myprefix = ">>>>>>>>>>>>>>>>> calc_surf_isec_inmem():"
    print(f"{myprefix} started", flush=True)

    # Fence polyline as an (N, 4) array: x, y, z (zeros here), horizontal length
    fence_arr = np.array(
        [cutting_plane.x_arr, cutting_plane.y_arr, np.zeros(len(cutting_plane.y_arr)), cutting_plane.length_arr]
    ).T

    access_token = authenticated_user.get_sumo_access_token()

    reals = range(0, num_reals)

    xtgeo_surf_arr = []
    items_to_fetch_list = []

    # Partition realizations into cache hits and surfaces that must be fetched
    for real in reals:
        cache_entry = IN_MEM_SURF_CACHE.get(case_uuid, ensemble_name, name, attribute, real)
        if cache_entry is not None:
            xtgeo_surf_arr.append(cache_entry.surf)
        else:
            items_to_fetch_list.append(
                SurfItem(
                    case_uuid=case_uuid,
                    ensemble_name=ensemble_name,
                    name=name,
                    attribute=attribute,
                    real=real,
                )
            )

    print(f"{myprefix} {len(xtgeo_surf_arr)=}", flush=True)
    print(f"{myprefix} {len(items_to_fetch_list)=}", flush=True)

    if len(items_to_fetch_list) > 0:
        # NOTE(review): Pool.map blocks the event loop for the entire download;
        # consider offloading via loop.run_in_executor if this leaves the
        # experiment stage.
        with multiprocessing.Pool(initializer=init_access, initargs=(access_token, case_uuid, ensemble_name)) as pool:
            res_item_arr = pool.map(fetch_a_surf, items_to_fetch_list)
            print(f"{myprefix} back from map {len(res_item_arr)=}", flush=True)

        # Pool.map preserves input order, so results align with items_to_fetch_list
        for fetch_item, res_item in zip(items_to_fetch_list, res_item_arr):
            xtgeo_surf = None
            if res_item is not None:
                print(f"{myprefix} {type(res_item)=}", flush=True)
                byte_stream = io.BytesIO(res_item)
                xtgeo_surf = xtgeo.surface_from_file(byte_stream)

            # Cache the outcome even when the fetch failed (xtgeo_surf stays None)
            xtgeo_surf_arr.append(xtgeo_surf)
            IN_MEM_SURF_CACHE.set(
                case_uuid,
                ensemble_name,
                fetch_item.name,
                fetch_item.attribute,
                fetch_item.real,
                cache_entry=SurfCacheEntry(xtgeo_surf),
            )

    intersections = []

    for xtgeo_surf in xtgeo_surf_arr:
        # Explicit None check: cached failures are None entries
        if xtgeo_surf is not None:
            line = xtgeo_surf.get_randomline(fence_arr)
            intersections.append(
                schemas.SurfaceIntersectionData(name="someName", hlen_arr=line[:, 0].tolist(), z_arr=line[:, 1].tolist())
            )

    return intersections
4 changes: 3 additions & 1 deletion backend/src/backend/user_session/routers/surface/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from src.backend.experiments.calc_surf_isec_multiprocess import calc_surf_isec_multiprocess
from src.backend.experiments.calc_surf_isec_aiomultiproc import calc_surf_isec_aiomultiproc
from src.backend.experiments.calc_surf_isec_custom import calc_surf_isec_custom
from src.backend.experiments.calc_surf_isec_inmem import calc_surf_isec_inmem


LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -61,8 +62,9 @@ async def post_calc_surf_isec_experiments(
#intersections = await calc_surf_isec_fetch_first(perf_metrics, authenticated_user, case_uuid, ensemble_name, name, attribute, num_reals, cutting_plane)
#intersections = await calc_surf_isec_queue(perf_metrics, authenticated_user, case_uuid, ensemble_name, name, attribute, num_reals, num_workers, cutting_plane)
#intersections = await calc_surf_isec_multiprocess(perf_metrics, authenticated_user, case_uuid, ensemble_name, name, attribute, num_reals, cutting_plane)
intersections = await calc_surf_isec_aiomultiproc(authenticated_user, case_uuid, ensemble_name, name, attribute, num_reals, cutting_plane)
#intersections = await calc_surf_isec_aiomultiproc(authenticated_user, case_uuid, ensemble_name, name, attribute, num_reals, cutting_plane)
#intersections = await calc_surf_isec_custom(perf_metrics, authenticated_user, case_uuid, ensemble_name, name, attribute, num_reals, num_workers, cutting_plane)
intersections = await calc_surf_isec_inmem(perf_metrics, authenticated_user, case_uuid, ensemble_name, name, attribute, num_reals, cutting_plane)

LOGGER.debug(f"route calc_surf_isec_experiments - intersected {len(intersections)} surfaces in: {perf_metrics.to_string()}")

Expand Down
66 changes: 66 additions & 0 deletions backend/src/services/sumo_access/surface_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,72 @@ def get_realization_surface_data_sync(

return xtgeo_surf

def get_realization_surface_bytes_sync(
    self, real_num: int, name: str, attribute: str, time_or_interval_str: Optional[str] = None
) -> Optional[bytes]:
    """
    Get the raw surface blob for a realization surface.

    Unlike get_realization_surface_data_sync(), this returns the undecoded
    bytes of the blob; returns None if no matching surface is found in Sumo.
    """
    timer = PerfTimer()
    addr_str = self._make_addr_str(real_num, name, attribute, time_or_interval_str)

    if time_or_interval_str is None:
        time_filter = TimeFilter(TimeType.NONE)
    else:
        # Either a single timestamp "t" or an interval "t1/t2"
        timestamp_arr = time_or_interval_str.split("/", 1)
        # NOTE(review): with maxsplit=1 this guard can never trigger (split
        # always yields 1 or 2 elements) — kept for parity with the sibling
        # get_realization_surface_data_sync()
        if len(timestamp_arr) == 0 or len(timestamp_arr) > 2:
            raise ValueError("time_or_interval_str must contain a single timestamp or interval")
        if len(timestamp_arr) == 1:
            time_filter = TimeFilter(
                TimeType.TIMESTAMP,
                start=timestamp_arr[0],
                end=timestamp_arr[0],
                exact=True,
            )
        else:
            time_filter = TimeFilter(
                TimeType.INTERVAL,
                start=timestamp_arr[0],
                end=timestamp_arr[1],
                exact=True,
            )

    surface_collection: SurfaceCollection = self._case.surfaces.filter(
        iteration=self._iteration_name,
        aggregation=False,
        realization=real_num,
        name=name,
        tagname=attribute,
        time=time_filter,
    )

    surf_count = len(surface_collection)
    if surf_count == 0:
        LOGGER.warning(f"No realization surface found in Sumo for {addr_str}")
        return None
    if surf_count > 1:
        LOGGER.warning(f"Multiple ({surf_count}) surfaces found in Sumo for: {addr_str}. Returning first surface.")

    sumo_surf = surface_collection[0]
    et_locate_ms = timer.lap_ms()

    # Download the blob directly instead of going through xtgeo parsing
    surf_bytes: bytes = self._sumo_client.get(f"/objects('{sumo_surf.uuid}')/blob")
    et_download_ms = timer.lap_ms()

    size_mb = len(surf_bytes) / (1024 * 1024)

    LOGGER.debug(
        f"Got realization surface bytes from Sumo in: {timer.elapsed_ms()}ms ("
        f"locate={et_locate_ms}ms, "
        f"download={et_download_ms}ms, "
        f"[{size_mb:.2f}MB] "
        f"({addr_str})"
    )

    return surf_bytes


async def get_realization_surface_data_async(
self, real_num: int, name: str, attribute: str, time_or_interval_str: Optional[str] = None
) -> Optional[xtgeo.RegularSurface]:
Expand Down

0 comments on commit 3e094ad

Please sign in to comment.