Skip to content

Commit

Permalink
let's try multiprocessing for storage grpc
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxiBoether committed Sep 27, 2023
1 parent 0c3745a commit 79451f3
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 16 deletions.
83 changes: 67 additions & 16 deletions modyn/storage/internal/grpc/grpc_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
"""GRPC server context manager."""

import contextlib
import datetime
import logging
import multiprocessing as mp
import os
import socket
import time
from concurrent import futures
from typing import Any

import grpc
from modyn.storage.internal.grpc.generated.storage_pb2_grpc import add_StorageServicer_to_server
Expand All @@ -11,6 +18,49 @@
logger = logging.getLogger(__name__)


@contextlib.contextmanager
def _reserve_port(port: str):
"""Find and reserve a port for all subprocesses to use."""
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
raise RuntimeError("Failed to set SO_REUSEPORT.")
sock.bind(("", int(port)))
try:
assert sock.getsockname()[1] == int(port)
yield port
finally:
sock.close()


def _wait_forever(server):
try:
while True:
time.sleep(datetime.timedelta(days=1).total_seconds())
except KeyboardInterrupt:
server.stop(None)


def _run_server(bind_address: str, modyn_config: dict) -> None:
    """Start a gRPC storage server in the current process and block until interrupted.

    Used as the ``target`` of the worker processes spawned by ``GRPCServer``.

    Args:
        bind_address (str): Address handed to ``add_insecure_port``.
        modyn_config (dict): Configuration passed on to ``StorageGRPCServicer``.
    """
    # Log through the module-level logger, like the rest of this module,
    # rather than through the root logger.
    logger.info(f"[{os.getpid()}] Starting new server.")

    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=16),
        options=[
            ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
            ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
            # Required so that several processes can listen on the same port.
            ("grpc.so_reuseport", 1),
        ],
    )
    add_StorageServicer_to_server(StorageGRPCServicer(modyn_config), server)
    server.add_insecure_port(bind_address)
    server.start()
    _wait_forever(server)


class GRPCServer:
"""GRPC server context manager."""

Expand All @@ -21,28 +71,28 @@ def __init__(self, modyn_config: dict) -> None:
modyn_config (dict): Configuration of the storage module.
"""
self.modyn_config = modyn_config
self.server = grpc.server(
futures.ThreadPoolExecutor(
max_workers=64,
),
options=[
("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
],
)

def __enter__(self) -> grpc.Server:
self.workers = []

def __enter__(self) -> "GRPCServer":
    """Enter the context manager and start the worker processes.

    Reserves the configured storage port and spawns 64 processes, each of
    which runs ``_run_server`` bound to that port.

    Returns:
        GRPCServer: This instance. Call ``wait_for_termination`` on it to
            block until all worker processes have exited.
    """
    port = self.modyn_config["storage"]["port"]
    logger.info(f"Starting server. Listening on port {port}")

    # `__exit__` deletes this attribute, so recreate it on every entry.
    self.workers = []

    num_workers = 64
    # Keep the port reserved while the workers are being started.
    with _reserve_port(port) as reserved_port:
        bind_address = "[::]:" + reserved_port
        for _ in range(num_workers):
            worker = mp.Process(target=_run_server, args=(bind_address, self.modyn_config))
            worker.start()
            self.workers.append(worker)

    return self

def wait_for_termination(self) -> None:
    """Block until every process in ``self.workers`` has been joined.

    The processes are joined one after another, in the order they are stored.
    """
    for process in self.workers:
        process.join()

def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> None:
"""Exit the context manager.
Expand All @@ -52,4 +102,5 @@ def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> Non
exc_val (Exception): exception value
exc_tb (Exception): exception traceback
"""
self.server.stop(0)
self.wait_for_termination()
del self.workers
10 changes: 10 additions & 0 deletions modyn/storage/internal/grpc/storage_grpc_servicer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Storage GRPC servicer."""

import logging
import os
import threading
from typing import Iterable, Tuple

import grpc
from modyn.common.benchmark.stopwatch import Stopwatch
from modyn.storage.internal.database.models import Dataset, File, Sample
from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection
from modyn.storage.internal.database.storage_database_utils import get_file_wrapper, get_filesystem_wrapper
Expand Down Expand Up @@ -64,6 +67,9 @@ def Get(self, request: GetRequest, context: grpc.ServicerContext) -> Iterable[Ge
Yields:
Iterator[Iterable[GetResponse]]: Response containing the data for the given keys.
"""
tid = threading.get_native_id()
pid = os.getpid()
logger.info(f"[{pid}][{tid}] Received request for {len(request.keys)} items.")
with StorageDatabaseConnection(self.modyn_config) as database:
session = database.session

Expand All @@ -73,12 +79,16 @@ def Get(self, request: GetRequest, context: grpc.ServicerContext) -> Iterable[Ge
yield GetResponse()
return

stopw = Stopwatch()
stopw.start("GetSamples")
samples: list[Sample] = (
session.query(Sample)
.filter(and_(Sample.sample_id.in_(request.keys), Sample.dataset_id == dataset.dataset_id))
.order_by(Sample.file_id)
.all()
)
samples_time = stopw.stop()
logger.info(f"[{pid}][{tid}] Getting samples took {samples_time / 1000}s.")

if len(samples) == 0:
logger.error("No samples found in the database.")
Expand Down

0 comments on commit 79451f3

Please sign in to comment.