fix(multiprocess_predictor): cancel requests when task cancelled
PaulHax committed Jan 30, 2025
1 parent c20e9bd commit a43de11
Showing 3 changed files with 100 additions and 77 deletions.
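
The substantive change is in multiprocess_predictor.py: every request sent to the worker process now gets an asyncio future, and cancelling the task that awaits it also withdraws the request from the worker's queue. A minimal caller-side sketch of what this enables (the MultiprocessPredictor class name is assumed from the module path, since the diff only shows its methods; the timeout and image are illustrative):

    import asyncio
    from PIL import Image
    # Class name assumed from the module path; the diff shows only its methods.
    from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor

    async def main():
        predictor = MultiprocessPredictor(model_name="facebook/detr-resnet-50")
        image = Image.new("RGB", (64, 64))
        try:
            # On timeout, wait_for cancels the awaiting task; with this commit
            # that cancellation also removes the still-queued request, so the
            # worker never spends time on stale work.
            result = await asyncio.wait_for(
                predictor.infer({"img_0": image}), timeout=5.0
            )
            print(result)
        except asyncio.TimeoutError:
            print("inference request cancelled")
        finally:
            predictor.shutdown()

    asyncio.run(main())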
16 changes: 9 additions & 7 deletions src/nrtk_explorer/app/embeddings.py

@@ -208,17 +208,19 @@ def on_run_transformations(self, id_to_image):
         self.transformed_images.add_images(id_to_image)
 
     def update_transformed_images(self, id_to_image):
-        new_to_plot = {
-            id: img
-            for id, img in id_to_image.items()
+        ids_to_plot = [
+            id
+            for id in id_to_image.keys()
             if image_id_to_dataset_id(id) not in self._stashed_points_transformations
-        }
+        ]
+        images_to_plot = (id_to_image[id] for id in ids_to_plot)
 
-        transformation_features = self.extractor.extract(list(new_to_plot.values()))
+        transformation_features = self.extractor.extract(images_to_plot)
         points = self.compute_points(self.features, transformation_features)
-        image_id_to_point = zip(new_to_plot.keys(), points)
 
-        updated_points = {image_id_to_dataset_id(id): point for id, point in image_id_to_point}
+        updated_points = {
+            image_id_to_dataset_id(id): point for id, point in zip(ids_to_plot, points)
+        }
         self._stashed_points_transformations = {
             **self._stashed_points_transformations,
             **updated_points,
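
The rewrite above extracts features only for images whose ids are not already plotted, and it passes the extractor a lazy generator instead of materializing a list. Alignment between ids and points relies on zip preserving the order of ids_to_plot. The pattern in isolation, as a standalone sketch with hypothetical stand-ins (plot_new_points, extract, compute_points) for the app's extractor and stash:

    def plot_new_points(id_to_image, stashed, extract, compute_points):
        # Only ids we have not plotted yet; list order is preserved.
        ids_to_plot = [i for i in id_to_image if i not in stashed]
        # Generator: images are pulled lazily, no intermediate dict or list.
        images_to_plot = (id_to_image[i] for i in ids_to_plot)
        points = compute_points(extract(images_to_plot))
        # zip pairs each id with its point because both follow ids_to_plot order.
        return {**stashed, **dict(zip(ids_to_plot, points))}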
24 changes: 4 additions & 20 deletions src/nrtk_explorer/app/images/images.py

@@ -1,24 +1,14 @@
-import base64
 import psutil
-from io import BytesIO
 from PIL import Image
 from trame.decorators import TrameApp, change, controller
 from nrtk_explorer.app.images.image_ids import (
     dataset_id_to_image_id,
     dataset_id_to_transformed_image_id,
 )
-from nrtk_explorer.app.trame_utils import delete_state
 from nrtk_explorer.app.images.cache import LruCache
 from nrtk_explorer.library.transforms import ImageTransform
 
 
-def convert_to_base64(img: Image.Image) -> str:
-    """Convert image to base64 string"""
-    buf = BytesIO()
-    img.save(buf, format="png")
-    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
-
-
 IMAGE_CACHE_SIZE_DEFAULT = 50
 AVALIBLE_MEMORY_TO_TAKE_FACTOR = 0.3
 
@@ -29,7 +19,7 @@ def __init__(self, server):
         self.server = server
         self.original_images = LruCache(IMAGE_CACHE_SIZE_DEFAULT)
         self.transformed_images = LruCache(IMAGE_CACHE_SIZE_DEFAULT)
-        self._should_reset_cache = True
+        self._should_ajust_cache_size = True
 
     def _ajust_cache_size(self, image_example: Image.Image):
         img_size = len(image_example.tobytes())
@@ -44,8 +34,8 @@ def _load_image(self, dataset_id: str):
         img = self.server.context.dataset.get_image(int(dataset_id))
         img.load()  # Avoid OSError(24, 'Too many open files')
 
-        if self._should_reset_cache:
-            self._should_reset_cache = False
+        if self._should_ajust_cache_size:
+            self._should_ajust_cache_size = False
             self._ajust_cache_size(img)  # assuming images in dataset are similar size
 
         # transforms and base64 encoding require RGB mode
@@ -58,12 +48,6 @@ def get_image(self, dataset_id: str, **kwargs):
         self.original_images.add_item(image_id, image, **kwargs)
         return image
 
-    def _add_image_to_state(self, image_id: str, image: Image.Image):
-        self.server.state[image_id] = convert_to_base64(image)
-
-    def _delete_from_state(self, state_key: str):
-        delete_state(self.server.state, state_key)
-
     def get_image_without_cache_eviction(self, dataset_id: str):
         """
         Does not remove items from cache, only adds.
@@ -105,7 +89,7 @@ def get_transformed_image_without_cache_eviction(
     def clear_all(self, **kwargs):
        self.original_images.clear()
        self.clear_transformed()
-        self._should_reset_cache = True
+        self._should_ajust_cache_size = True
 
     @controller.add("apply_transform")
     def clear_transformed(self, **kwargs):
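
The rename from _should_reset_cache to _should_ajust_cache_size matches what the flag actually gates: sizing both LRU caches from the first image loaded. The diff shows only the first line of _ajust_cache_size, so the capacity formula below is an assumption pieced together from the visible psutil import and the AVALIBLE_MEMORY_TO_TAKE_FACTOR constant:

    import psutil

    IMAGE_CACHE_SIZE_DEFAULT = 50
    AVALIBLE_MEMORY_TO_TAKE_FACTOR = 0.3

    def cache_size_for(image_example):
        # Hypothetical reconstruction: budget a fraction of available memory
        # and divide by the size of one decoded image from the dataset.
        img_size = len(image_example.tobytes())
        budget = psutil.virtual_memory().available * AVALIBLE_MEMORY_TO_TAKE_FACTOR
        return max(IMAGE_CACHE_SIZE_DEFAULT, int(budget // img_size))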
137 changes: 87 additions & 50 deletions src/nrtk_explorer/library/multiprocess_predictor.py

@@ -8,8 +8,6 @@
 from enum import Enum
 from .predictor import Predictor
 
-WORKER_RESPONSE_TIMEOUT = 40
-
 
 class Command(Enum):
     SET_MODEL = "SET_MODEL"
@@ -35,7 +33,6 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu):
         command = Command(msg["command"])
         req_id = msg["req_id"]
         payload = msg.get("payload", {})
-        logger.debug(f"Worker: Received {command.value} with ID {req_id}")
 
         if command == Command.SET_MODEL:
             try:
@@ -61,6 +58,11 @@ def _child_worker(request_queue, result_queue, model_name, force_cpu):
                 logger.exception("Reset failed.")
                 result_queue.put((req_id, {"status": "ERROR", "message": str(e)}))
 
+        del msg
+        del command
+        del req_id
+        del payload
+
     logger.debug("Worker: shutting down.")


@@ -72,8 +74,15 @@ def __init__(self, model_name="facebook/detr-resnet-50", force_cpu=False):
         self._proc = None
         self._request_queue = None
         self._result_queue = None
+        self._pending_futures = {}
+
+        self.loop = asyncio.get_event_loop()
+
         self._start_process()
+
+        # Instead of a response thread, schedule an async task:
+        asyncio.ensure_future(self._poll_responses())
 
         def handle_shutdown(signum, frame):
             self.shutdown()
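
These __init__ additions set up what the next hunk uses: a dict of pending futures keyed by request id, and a long-lived poller task in place of a dedicated response thread. The request/response scheme in isolation, as a standalone runnable sketch (not the app's code):

    import asyncio
    import queue
    import uuid

    results = queue.Queue()
    pending = {}

    async def poll(loop):
        while True:
            # Blocking Queue.get runs in a thread so the event loop stays free.
            req_id, payload = await loop.run_in_executor(None, results.get)
            if req_id is None:  # sentinel: stop polling
                break
            future = pending.pop(req_id, None)
            if future and not future.done():
                future.set_result(payload)

    async def submit(loop):
        req_id = str(uuid.uuid4())
        future = loop.create_future()
        pending[req_id] = future
        results.put((req_id, "done"))  # stand-in for the worker's reply
        return await future

    async def main():
        loop = asyncio.get_running_loop()
        poller = asyncio.ensure_future(poll(loop))
        print(await submit(loop))  # -> done
        results.put((None, None))  # unblock the poller so it can exit
        await poller

    asyncio.run(main())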

@@ -98,66 +107,94 @@ def _start_process(self):
         )
         self._proc.start()
 
-    def _get_response(self, req_id, timeout=WORKER_RESPONSE_TIMEOUT):
+    async def _poll_responses(self):
         while True:
             try:
-                r_id, data = self._result_queue.get(timeout=timeout)
-            except queue.Empty:
-                raise TimeoutError("No response from worker.")
-            if r_id == req_id:
-                return data
-
-    def _wait_for_response(self, req_id):
-        return self._get_response(req_id, WORKER_RESPONSE_TIMEOUT)
-
-    async def _wait_for_response_async(self, req_id):
-        return await asyncio.get_event_loop().run_in_executor(
-            None, self._get_response, req_id, WORKER_RESPONSE_TIMEOUT
-        )
+                r_id, payload = await self.loop.run_in_executor(None, self._result_queue.get)
+            except (EOFError, KeyboardInterrupt):
+                break
+            with self._lock:
+                future = self._pending_futures.pop(r_id, None)
+            if future and not future.done():
+                future.set_result(payload)
+
+    async def _submit_request(self, command, payload):
+        future = self.loop.create_future()
+        req_id = str(uuid.uuid4())
+
+        def cleanup(_):
+            with self._lock:
+                self._pending_futures.pop(req_id, None)
+                # Remove the request if it's still in the queue. Probably got canceled.
+                stashed_requests = []
+                while not self._request_queue.empty():
+                    try:
+                        req = self._request_queue.get_nowait()
+                        if req["req_id"] != req_id:
+                            stashed_requests.append(req)
+                    except queue.Empty:
+                        break
+                for req in stashed_requests:
+                    self._request_queue.put(req)
+
+        future.add_done_callback(cleanup)
 
-    def set_model(self, model_name, force_cpu=False):
         with self._lock:
-            self.model_name = model_name
-            self.force_cpu = force_cpu
-            req_id = str(uuid.uuid4())
-            self._request_queue.put(
-                {
-                    "command": Command.SET_MODEL.value,
-                    "req_id": req_id,
-                    "payload": {
-                        "model_name": self.model_name,
-                        "force_cpu": self.force_cpu,
-                    },
-                }
-            )
-        return self._wait_for_response(req_id)
+            self._pending_futures[req_id] = future
+
+        self._request_queue.put(
+            {
+                "command": command.value,
+                "req_id": req_id,
+                "payload": payload,
+            }
+        )
+
+        try:
+            r = await future
+            return r
+        except asyncio.CancelledError:
+            cleanup(None)
+            raise
 
     async def infer(self, images):
         if not images:
             return {}
+        resp = await self._submit_request(Command.INFER, {"images": images})
+        return resp.get("result")
+
+    def _run_coro(self, coro):
+        if self.loop.is_running():
+            return asyncio.ensure_future(coro)
+        else:
+            return self.loop.run_until_complete(coro)
 
+    def set_model(self, model_name, force_cpu=False):
         with self._lock:
-            req_id = str(uuid.uuid4())
-            new_req = {
-                "command": Command.INFER.value,
-                "req_id": req_id,
-                "payload": {"images": images},
-            }
-            self._request_queue.put(new_req)
+            self.model_name = model_name
+            self.force_cpu = force_cpu
 
-        resp = await self._wait_for_response_async(req_id)
-        return resp.get("result")
+        async def _async_set():
+            return await self._submit_request(
+                Command.SET_MODEL, {"model_name": self.model_name, "force_cpu": self.force_cpu}
+            )
+
+        return self._run_coro(_async_set())
 
     def reset(self):
-        with self._lock:
-            req_id = str(uuid.uuid4())
-            self._request_queue.put({"command": Command.RESET.value, "req_id": req_id})
-        return self._wait_for_response(req_id)
+        async def _async_reset():
+            return await self._submit_request(Command.RESET, {})
+
+        return self._run_coro(_async_reset())
 
     def shutdown(self):
-        with self._lock:
-            try:
-                self._request_queue.put(None)
-            except Exception:
-                logging.warning("Could not send exit message to worker.")
-        if self._proc:
-            self._proc.join()
+        async def _async_shutdown():
+            with self._lock:
+                try:
+                    self._request_queue.put(None)
+                except Exception:
+                    logging.warning("Could not send exit message to worker.")
+            if self._proc:
+                self._proc.join()
+
+        return self._run_coro(_async_shutdown())
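
set_model, reset, and shutdown keep synchronous signatures but now delegate to coroutines through _run_coro: with a loop already running, the call returns a scheduled task the caller can await; without one, it blocks until the worker responds. A usage sketch under the same assumptions as above:

    async def swap_model(predictor):
        # Inside a running loop, _run_coro returns a scheduled task;
        # await it to get the worker's SET_MODEL response.
        return await predictor.set_model("facebook/detr-resnet-50", force_cpu=True)

    def swap_model_blocking(predictor):
        # With no running loop, the same call blocks and returns the response.
        return predictor.set_model("facebook/detr-resnet-50")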
