Skip to content

Commit

Permalink
refactor to generic grpc server
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxiBoether committed Oct 2, 2023
1 parent d097bf5 commit 15ec797
Show file tree
Hide file tree
Showing 14 changed files with 198 additions and 166 deletions.
10 changes: 10 additions & 0 deletions modyn/common/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
This submodule implements functions to run gRPC servers using multiprocessing.
"""
import os

from .grpc_helpers import GenericGRPCServer # noqa: F401

# Export every Python module that sits next to this file (except this file itself)
# by listing its name, without the ".py" suffix, in ``__all__``.
files = os.listdir(os.path.dirname(__file__))
files.remove("__init__.py")
__all__ = [module_file[: -len(".py")] for module_file in files if module_file.endswith(".py")]
107 changes: 107 additions & 0 deletions modyn/common/grpc/grpc_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import contextlib
import datetime
import logging
import multiprocessing as mp
import os
import socket
import time
from concurrent import futures
from typing import Any, Callable

import grpc
from modyn.utils import MAX_MESSAGE_SIZE

logger = logging.getLogger(__name__)

# Size of the thread pool handed to the gRPC server inside each worker process.
PROCESS_THREAD_WORKERS = 16
# Number of worker processes GenericGRPCServer starts on the shared port.
NUM_GPRC_PROCESSES = 64


@contextlib.contextmanager
def reserve_port(port: str):
    """Reserve ``port`` for the duration of the ``with`` block.

    An IPv6 TCP socket with ``SO_REUSEPORT`` enabled is bound to the port and kept
    open while the body of the ``with`` statement runs.

    Args:
        port: port number, given as a string.

    Yields:
        The unchanged ``port`` string.

    Raises:
        ValueError: if ``port`` is not an integer string.
        RuntimeError: if ``SO_REUSEPORT`` could not be enabled on the socket.
        OSError: if the socket cannot be created or bound.
    """
    # Parse before creating the socket so a malformed port cannot leave one behind.
    port_number = int(port)

    # The ``with`` closes the socket on every path, including failures in
    # setsockopt()/bind() that previously happened outside the try/finally.
    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
            raise RuntimeError("Failed to set SO_REUSEPORT.")
        sock.bind(("", port_number))
        # Sanity check that the kernel really gave us the requested port.
        assert sock.getsockname()[1] == port_number
        yield port


def _wait_forever(server):
    """Sleep until a ``KeyboardInterrupt`` arrives, then stop ``server``.

    Args:
        server: object exposing ``stop(grace)``; it is called with ``None`` once
            the sleep loop is interrupted.
    """
    nap_seconds = datetime.timedelta(days=1).total_seconds()
    try:
        # Sleep in day-long naps; only an interrupt leaves this loop.
        while True:
            time.sleep(nap_seconds)
    except KeyboardInterrupt:
        server.stop(None)


def _run_server_worker(bind_address: str, add_servicer_callback: Callable):
    """Run one gRPC server in the current process until it is interrupted.

    Args:
        bind_address: address passed to ``add_insecure_port``.
        add_servicer_callback: called with the freshly created server before the
            port is added; it is expected to register the servicer(s) on it.
    """
    logging.info(f"[{os.getpid()}] Starting new gRPC server process.")

    thread_pool = futures.ThreadPoolExecutor(max_workers=PROCESS_THREAD_WORKERS)
    channel_options = [
        ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
        ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
        # Lets this server share the port with the other worker processes.
        ("grpc.so_reuseport", 1),
    ]
    server = grpc.server(thread_pool, options=channel_options)

    add_servicer_callback(server)
    server.add_insecure_port(bind_address)
    server.start()
    _wait_forever(server)


class GenericGRPCServer:
    """Context manager that serves a gRPC service from several worker processes.

    Entering the context starts ``NUM_GPRC_PROCESSES`` processes that each run
    ``_run_server_worker`` on the same port; leaving it blocks until all of them
    have exited.
    """

    def __init__(self, port: str, add_servicer_callback: Callable) -> None:
        """Store the configuration; no process is started yet.

        Args:
            port: port the workers listen on. It is appended to the bind address
                with ``+``, so it has to be a string.
            add_servicer_callback: callable invoked in every worker process with
                that worker's gRPC server; it is expected to register the
                servicer(s) on it.
                NOTE(review): it is handed to ``multiprocessing.Process`` as an
                argument, so it presumably must be picklable when the "spawn"
                start method is active - confirm for closures/local functions.
        """
        self.port = port
        self.add_servicer_callback = add_servicer_callback
        # Worker processes started by __enter__ that have not been joined yet.
        self.workers = []

    def __enter__(self) -> Any:
        """Start the worker processes.

        Returns:
            This server object (not a ``grpc.Server``).
        """
        logger.info(f"[{os.getpid()}] Starting server. Listening on port {self.port}")
        # The reservation is only held while the workers are being started.
        with reserve_port(self.port) as port:
            bind_address = "[::]:" + port
            for _ in range(NUM_GPRC_PROCESSES):
                worker = mp.Process(
                    target=_run_server_worker,
                    args=(bind_address, self.add_servicer_callback),
                )
                worker.start()
                self.workers.append(worker)

        return self

    def wait_for_termination(self) -> None:
        """Block until every started worker process has exited."""
        for worker in self.workers:
            worker.join()

    def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> None:
        """Wait for all workers to finish and forget them.

        Args:
            exc_type (type): exception type
            exc_val (Exception): exception value
            exc_tb (Exception): exception traceback
        """
        self.wait_for_termination()
        # Reset instead of deleting the attribute: the object stays usable, so a
        # later wait_for_termination() or re-entry does not hit AttributeError.
        self.workers = []
7 changes: 6 additions & 1 deletion modyn/selector/internal/grpc/selector_grpc_servicer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import logging
import os
import threading
from typing import Iterable

import grpc
Expand Down Expand Up @@ -59,8 +61,11 @@ def get_sample_keys_and_weights( # pylint: disable-next=unused-argument
request.worker_id,
request.partition_id,
)
tid = threading.get_native_id()
pid = os.getpid()

logger.info(
f"[Pipeline {pipeline_id}]: Fetching samples for trigger id {trigger_id}"
f"[{pid}][{tid}][Pipeline {pipeline_id}]: Fetching samples for trigger id {trigger_id}"
+ f" and worker id {worker_id} and partition id {partition_id}"
)

Expand Down
73 changes: 14 additions & 59 deletions modyn/selector/internal/grpc/selector_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from concurrent import futures

import grpc
from modyn.common.grpc import GenericGRPCServer
from modyn.selector.internal.grpc.generated.selector_pb2_grpc import add_SelectorServicer_to_server # noqa: E402, E501
from modyn.selector.internal.grpc.selector_grpc_servicer import SelectorGRPCServicer
from modyn.selector.internal.selector_manager import SelectorManager
Expand All @@ -16,72 +17,26 @@
logger = logging.getLogger(__name__)


@contextlib.contextmanager
def _reserve_port(port: str):
    """Reserve ``port`` for the duration of the ``with`` block.

    An IPv6 TCP socket with ``SO_REUSEPORT`` enabled is bound to the port and kept
    open while the body of the ``with`` statement runs.

    Args:
        port: port number, given as a string.

    Yields:
        The unchanged ``port`` string.

    Raises:
        ValueError: if ``port`` is not an integer string.
        RuntimeError: if ``SO_REUSEPORT`` could not be enabled on the socket.
        OSError: if the socket cannot be created or bound.
    """
    # Parse before creating the socket so a malformed port cannot leave one behind.
    port_number = int(port)

    # The ``with`` closes the socket on every path, including failures in
    # setsockopt()/bind() that previously happened outside the try/finally.
    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
            raise RuntimeError("Failed to set SO_REUSEPORT.")
        sock.bind(("", port_number))
        # Sanity check that the kernel really gave us the requested port.
        assert sock.getsockname()[1] == port_number
        yield port


def _wait_forever(server):
    """Sleep until a ``KeyboardInterrupt`` arrives, then stop ``server``.

    Args:
        server: object exposing ``stop(grace)``; it is called with ``None`` once
            the sleep loop is interrupted.
    """
    nap_seconds = datetime.timedelta(days=1).total_seconds()
    try:
        # Sleep in day-long naps; only an interrupt leaves this loop.
        while True:
            time.sleep(nap_seconds)
    except KeyboardInterrupt:
        server.stop(None)


def _run_server(bind_address, selector_manager, sample_batch_size):
    """Run one selector gRPC server in the current process until interrupted.

    Args:
        bind_address: address passed to ``add_insecure_port``.
        selector_manager: forwarded to ``SelectorGRPCServicer``.
        sample_batch_size: forwarded to ``SelectorGRPCServicer``.
    """
    logging.info(f"[{os.getpid()}] Starting new server.")

    thread_pool = futures.ThreadPoolExecutor(max_workers=16)
    channel_options = [
        ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
        ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
        ("grpc.so_reuseport", 1),
    ]
    server = grpc.server(thread_pool, options=channel_options)

    servicer = SelectorGRPCServicer(selector_manager, sample_batch_size)
    add_SelectorServicer_to_server(servicer, server)
    server.add_insecure_port(bind_address)
    server.start()
    _wait_forever(server)


class SelectorServer:
class SelectorGRPCServer(GenericGRPCServer):
def __init__(self, modyn_config: dict) -> None:
self.modyn_config = modyn_config
self.selector_manager = SelectorManager(modyn_config)
self.sample_batch_size = self.modyn_config["selector"]["sample_batch_size"]
self.workers = []

def run(self) -> None:
port = self.modyn_config["selector"]["port"]
logger.info(f"Starting server. Listening on port {port}")
with _reserve_port(port) as port:
bind_address = "[::]:" + port
for _ in range(64):
worker = mp.Process(
target=_run_server,
args=(bind_address, self.selector_manager, self.sample_batch_size),
)
worker.start()
self.workers.append(worker)
def callback(server):
    # Registers the selector servicer on the gRPC server it is given.
    # NOTE(review): this local closure is passed on to mp.Process by the base
    # class; with the "spawn" start method process arguments are pickled and
    # local functions presumably cannot be - confirm this starts correctly.
    add_SelectorServicer_to_server(SelectorGRPCServicer(self.selector_manager, self.sample_batch_size), server)

# The selector has to listen on its own configured port; the "storage" port
# used here before was a copy-paste slip from the storage server.
super().__init__(self.modyn_config["selector"]["port"], callback)

for worker in self.workers:
worker.join()
def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> None:
"""Exit the context manager.
Args:
exc_type (type): exception type
exc_val (Exception): exception value
exc_tb (Exception): exception traceback
"""
super().__exit__(exc_type, exc_val, exc_tb)
if (
"cleanup_trigger_samples_after_shutdown" in self.modyn_config["selector"]
and self.modyn_config["selector"]["cleanup_trigger_samples_after_shutdown"]
Expand Down
2 changes: 1 addition & 1 deletion modyn/selector/internal/selector_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _populate_pipeline_if_exists(self, pipeline_id: int) -> None:
return
logging.info(
f"[{os.getpid()}] Instantiating new selector for pipeline {pipeline_id}"
+ " that was in the DB but previously unknown to this process.."
+ " that was in the DB but previously unknown to this process"
)

self._instantiate_selector(pipeline_id, pipeline.num_workers, pipeline.selection_strategy)
Expand Down
2 changes: 1 addition & 1 deletion modyn/selector/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _populate_trigger_if_exists(self, trigger_id: int) -> None:
return

with MetadataDatabaseConnection(self._modyn_config) as database:
trigger: Optional[Trigger] = database.session.get(Trigger, trigger_id, self._pipeline_id)
trigger: Optional[Trigger] = database.session.get(Trigger, (trigger_id, self._pipeline_id))
if trigger is None:
return

Expand Down
18 changes: 14 additions & 4 deletions modyn/selector/selector_entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import argparse
import logging
import multiprocessing as mp
import os
import pathlib

import yaml
from modyn.selector.internal.grpc.selector_server import SelectorServer
from modyn.selector.internal.grpc.selector_server import SelectorGRPCServer

logging.basicConfig(
level=logging.NOTSET,
Expand All @@ -12,6 +14,14 @@
)
logger = logging.getLogger(__name__)

# We need to do this at the top because other dependencies otherwise set fork.
try:
    mp.set_start_method("spawn")
except RuntimeError:
    # set_start_method() refuses to run twice. That is only a problem when the
    # method already in place is not "spawn"; test runs under pytest are exempt.
    if mp.get_start_method() != "spawn" and "PYTEST_CURRENT_TEST" not in os.environ:
        # logging interpolates lazily with %-style placeholders; "{}" with an
        # extra argument triggers a formatting error instead of the message.
        logger.error("Start method is already set to %s", mp.get_start_method())
        raise


def setup_argparser() -> argparse.ArgumentParser:
parser_ = argparse.ArgumentParser(description="Modyn Selector")
Expand All @@ -35,9 +45,9 @@ def main() -> None:
modyn_config = yaml.safe_load(config_file)

logger.info("Initializing selector server.")
selector = SelectorServer(modyn_config)
logger.info("Starting selector server.")
selector.run()

with SelectorGRPCServer(modyn_config):
pass

logger.info("Selector server returned, exiting.")

Expand Down
86 changes: 5 additions & 81 deletions modyn/storage/internal/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,57 +11,15 @@
from typing import Any

import grpc
from modyn.common.grpc import GenericGRPCServer
from modyn.storage.internal.grpc.generated.storage_pb2_grpc import add_StorageServicer_to_server
from modyn.storage.internal.grpc.storage_grpc_servicer import StorageGRPCServicer
from modyn.utils import MAX_MESSAGE_SIZE

logger = logging.getLogger(__name__)


@contextlib.contextmanager
def _reserve_port(port: str):
    """Reserve ``port`` for the duration of the ``with`` block.

    An IPv6 TCP socket with ``SO_REUSEPORT`` enabled is bound to the port and kept
    open while the body of the ``with`` statement runs.

    Args:
        port: port number, given as a string.

    Yields:
        The unchanged ``port`` string.

    Raises:
        ValueError: if ``port`` is not an integer string.
        RuntimeError: if ``SO_REUSEPORT`` could not be enabled on the socket.
        OSError: if the socket cannot be created or bound.
    """
    # Parse before creating the socket so a malformed port cannot leave one behind.
    port_number = int(port)

    # The ``with`` closes the socket on every path, including failures in
    # setsockopt()/bind() that previously happened outside the try/finally.
    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
            raise RuntimeError("Failed to set SO_REUSEPORT.")
        sock.bind(("", port_number))
        # Sanity check that the kernel really gave us the requested port.
        assert sock.getsockname()[1] == port_number
        yield port


def _wait_forever(server):
    """Sleep until a ``KeyboardInterrupt`` arrives, then stop ``server``.

    Args:
        server: object exposing ``stop(grace)``; it is called with ``None`` once
            the sleep loop is interrupted.
    """
    nap_seconds = datetime.timedelta(days=1).total_seconds()
    try:
        # Sleep in day-long naps; only an interrupt leaves this loop.
        while True:
            time.sleep(nap_seconds)
    except KeyboardInterrupt:
        server.stop(None)


def _run_server(bind_address, modyn_config):
    """Run one storage gRPC server in the current process until interrupted.

    Args:
        bind_address: address passed to ``add_insecure_port``.
        modyn_config: forwarded to ``StorageGRPCServicer``.
    """
    logging.info(f"[{os.getpid()}] Starting new server.")

    thread_pool = futures.ThreadPoolExecutor(max_workers=16)
    channel_options = [
        ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
        ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
        ("grpc.so_reuseport", 1),
    ]
    server = grpc.server(thread_pool, options=channel_options)

    servicer = StorageGRPCServicer(modyn_config)
    add_StorageServicer_to_server(servicer, server)
    server.add_insecure_port(bind_address)
    server.start()
    _wait_forever(server)


class GRPCServer:
class StorageGRPCServer(GenericGRPCServer):
"""GRPC server context manager."""

def __init__(self, modyn_config: dict) -> None:
Expand All @@ -71,42 +29,8 @@ def __init__(self, modyn_config: dict) -> None:
modyn_config (dict): Configuration of the storage module.
"""
self.modyn_config = modyn_config
self.workers = []

def __enter__(self) -> Any:
"""Enter the context manager.
def callback(server):
add_StorageServicer_to_server(StorageGRPCServicer(modyn_config), server)

Returns:
grpc.Server: GRPC server
"""
port = self.modyn_config["storage"]["port"]
logger.info(f"Starting server. Listening on port {port}")
with _reserve_port(port) as port:
bind_address = "[::]:" + port
for _ in range(64):
worker = mp.Process(
target=_run_server,
args=(
bind_address,
self.modyn_config,
),
)
worker.start()
self.workers.append(worker)

return self

def wait_for_termination(self) -> None:
for worker in self.workers:
worker.join()

def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> None:
"""Exit the context manager.
Args:
exc_type (type): exception type
exc_val (Exception): exception value
exc_tb (Exception): exception traceback
"""
self.wait_for_termination()
del self.workers
super().__init__(self.modyn_config["storage"]["port"], callback)
Loading

0 comments on commit 15ec797

Please sign in to comment.