Skip to content

Commit

Permalink
refactor to generic grpc server
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxiBoether committed Oct 2, 2023
1 parent d097bf5 commit 5b2c4fe
Show file tree
Hide file tree
Showing 14 changed files with 231 additions and 168 deletions.
10 changes: 10 additions & 0 deletions modyn/common/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
This submodule implements functions to run gRPC servers using multiprocessing.
"""
import os

from .grpc_helpers import GenericGRPCServer # noqa: F401

# Export every sibling module of this package: each ``.py`` file found in the
# package directory except ``__init__.py`` itself. Filtering (instead of
# ``list.remove``) avoids a ValueError when ``__init__.py`` is not part of the
# directory listing, and sorting keeps ``__all__`` independent of the
# arbitrary order returned by ``os.listdir``.
files = sorted(f for f in os.listdir(os.path.dirname(__file__)) if f != "__init__.py")
__all__ = [f[:-3] for f in files if f.endswith(".py")]
120 changes: 120 additions & 0 deletions modyn/common/grpc/grpc_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import contextlib
import datetime
import logging
import multiprocessing as mp
import os
import pickle
import socket
import time
from concurrent import futures
from typing import Any, Callable

import grpc
from modyn.utils import MAX_MESSAGE_SIZE

logger = logging.getLogger(__name__)

# Size of the ThreadPoolExecutor handed to grpc.server() in each worker process.
PROCESS_THREAD_WORKERS = 16
# Number of worker processes GenericGRPCServer starts when its context is entered.
NUM_GPRC_PROCESSES = 64


@contextlib.contextmanager
def reserve_port(port: str):
    """Reserve ``port`` so that all worker subprocesses can bind to it.

    Opens an IPv6 TCP socket with ``SO_REUSEPORT`` enabled and binds it to
    ``port``. The socket stays open for the duration of the ``with`` body and
    is closed afterwards, including when configuring or binding it fails.

    Args:
        port: Port number to reserve, given as a string.

    Yields:
        The same ``port`` string that was passed in.

    Raises:
        RuntimeError: If ``SO_REUSEPORT`` could not be enabled, or if the
            socket ended up bound to a different port than requested.
    """
    sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    # Everything after socket creation lives inside the try block so that the
    # descriptor is released even if setsockopt() or bind() raises.
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
            raise RuntimeError("Failed to set SO_REUSEPORT.")
        sock.bind(("", int(port)))

        # Explicit check rather than `assert`, which is removed under `python -O`.
        bound_port = sock.getsockname()[1]
        if bound_port != int(port):
            raise RuntimeError(f"Requested port {port}, but the socket is bound to port {bound_port}.")

        yield port
    finally:
        sock.close()


def _wait_forever(server):
    """Block the calling process until it receives a KeyboardInterrupt.

    Sleeps in one-day intervals; once interrupted, ``server.stop(None)`` is
    called and the function returns.

    Args:
        server: Object exposing ``stop(grace)``; here the running gRPC server.
    """
    try:
        nap_seconds = datetime.timedelta(days=1).total_seconds()
        while True:
            time.sleep(nap_seconds)
    except KeyboardInterrupt:
        server.stop(None)


def _run_server_worker(bind_address: str, add_servicer_callback: Callable, modyn_config: dict) -> None:
    """Run one gRPC server inside a worker subprocess until interrupted.

    Args:
        bind_address: Address the server listens on, e.g. ``"[::]:<port>"``.
        add_servicer_callback: Called as ``add_servicer_callback(modyn_config, server)``
            to register the servicer(s) on the freshly created server.
        modyn_config: Modyn configuration that is forwarded to the callback.

    Raises:
        RuntimeError: If the server could not be bound to ``bind_address``.
    """
    # Log through the module logger (not the root logger), like the rest of this module.
    logger.info(f"[{os.getpid()}] Starting new gRPC server process.")

    server_options = [
        ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
        ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
        # Lets every worker process bind to the same port.
        ("grpc.so_reuseport", 1),
    ]
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=PROCESS_THREAD_WORKERS),
        options=server_options,
    )

    add_servicer_callback(modyn_config, server)

    # NOTE(review): older grpcio releases report a failed bind by returning 0
    # instead of raising; fail loudly so a worker never idles without a port.
    if server.add_insecure_port(bind_address) == 0:
        raise RuntimeError(f"[{os.getpid()}] Failed to bind gRPC server to {bind_address}.")

    server.start()
    _wait_forever(server)


class GenericGRPCServer:
    """Context manager that serves a gRPC service from several worker processes.

    Entering the context starts ``NUM_GPRC_PROCESSES`` subprocesses that each
    run ``_run_server_worker`` on the same port. Leaving the context blocks
    until all of these subprocesses have terminated.
    """

    def __init__(self, modyn_config: dict, port: str, add_servicer_callback: Callable) -> None:
        """Initialize the GRPC server.

        Args:
            modyn_config: Modyn configuration; passed to ``add_servicer_callback``
                in every worker process.
            port: Port (as a string) that all worker processes listen on.
            add_servicer_callback: Called as ``add_servicer_callback(modyn_config, server)``
                in each worker process to register the servicer(s) on that
                worker's gRPC server.
        """
        self.port = port
        self.modyn_config = modyn_config
        self.add_servicer_callback = add_servicer_callback
        self.workers: list = []

    def _worker_args(self, bind_address: str) -> tuple:
        """Return the positional arguments for ``_run_server_worker``.

        The tuple must match that function's signature:
        ``(bind_address, add_servicer_callback, modyn_config)``.
        """
        return (bind_address, self.add_servicer_callback, self.modyn_config)

    def __enter__(self) -> Any:
        """Enter the context manager and start the worker processes.

        Returns:
            GenericGRPCServer: this instance.
        """
        logger.info(f"[{os.getpid()}] Starting server. Listening on port {self.port}")
        with reserve_port(self.port) as port:
            bind_address = "[::]:" + port
            for _ in range(NUM_GPRC_PROCESSES):
                worker = mp.Process(
                    target=_run_server_worker,
                    args=self._worker_args(bind_address),
                )
                worker.start()
                self.workers.append(worker)

        return self

    def __getstate__(self):
        """Return the instance state for pickling, without ``add_servicer_callback``.

        Attributes that fail to pickle are reported through the module logger
        to make pickling problems easier to track down.
        """
        for variable_name, value in vars(self).items():
            try:
                pickle.dumps(value)
            except Exception:  # pylint: disable=broad-except
                # Diagnostics only: never let the probe itself abort pickling.
                logger.warning(f"{variable_name} with value {value} is not pickable")

        state = self.__dict__.copy()
        # pop() rather than del: the attribute may already be missing, e.g. on
        # an instance that was itself restored from a pickled state.
        state.pop("add_servicer_callback", None)
        return state

    def wait_for_termination(self) -> None:
        """Block until every started worker process has terminated."""
        for worker in self.workers:
            worker.join()

    def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> None:
        """Exit the context manager after all workers have terminated.

        Args:
            exc_type: exception type
            exc_val: exception value
            exc_tb: exception traceback
        """
        self.wait_for_termination()
        # Reset instead of deleting the attribute so that the instance remains
        # usable (wait_for_termination() or re-entering) after the context ends.
        self.workers = []
7 changes: 6 additions & 1 deletion modyn/selector/internal/grpc/selector_grpc_servicer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import logging
import os
import threading
from typing import Iterable

import grpc
Expand Down Expand Up @@ -59,8 +61,11 @@ def get_sample_keys_and_weights( # pylint: disable-next=unused-argument
request.worker_id,
request.partition_id,
)
tid = threading.get_native_id()
pid = os.getpid()

logger.info(
f"[Pipeline {pipeline_id}]: Fetching samples for trigger id {trigger_id}"
f"[{pid}][{tid}][Pipeline {pipeline_id}]: Fetching samples for trigger id {trigger_id}"
+ f" and worker id {worker_id} and partition id {partition_id}"
)

Expand Down
90 changes: 32 additions & 58 deletions modyn/selector/internal/grpc/selector_server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import contextlib
import datetime
import functools
import logging
import multiprocessing as mp
import os
import pickle
import socket
import time
from concurrent import futures

import grpc
from modyn.common.grpc import GenericGRPCServer
from modyn.selector.internal.grpc.generated.selector_pb2_grpc import add_SelectorServicer_to_server # noqa: E402, E501
from modyn.selector.internal.grpc.selector_grpc_servicer import SelectorGRPCServicer
from modyn.selector.internal.selector_manager import SelectorManager
Expand All @@ -16,72 +19,43 @@
logger = logging.getLogger(__name__)


@contextlib.contextmanager
def _reserve_port(port: str):
"""Find and reserve a port for all subprocesses to use."""
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
raise RuntimeError("Failed to set SO_REUSEPORT.")
sock.bind(("", int(port)))
try:
assert sock.getsockname()[1] == int(port)
yield port
finally:
sock.close()


def _wait_forever(server):
try:
while True:
time.sleep(datetime.timedelta(days=1).total_seconds())
except KeyboardInterrupt:
server.stop(None)
class SelectorGRPCServer(GenericGRPCServer):
@staticmethod
def callback(selector_manager, modyn_config, server):
add_SelectorServicer_to_server(
SelectorGRPCServicer(selector_manager, modyn_config["selector"]["sample_batch_size"]), server
)

def __init__(self, modyn_config: dict) -> None:
self.modyn_config = modyn_config
self.selector_manager = SelectorManager(modyn_config)

def _run_server(bind_address, selector_manager, sample_batch_size):
"""Start a server in a subprocess."""
logging.info(f"[{os.getpid()}] Starting new server.")
callback = functools.partial(SelectorGRPCServer.callback, selector_manager=self.selector_manager)

server = grpc.server(
futures.ThreadPoolExecutor(
max_workers=16,
),
options=[
("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
("grpc.so_reuseport", 1),
],
)
add_SelectorServicer_to_server(SelectorGRPCServicer(selector_manager, sample_batch_size), server)
server.add_insecure_port(bind_address)
server.start()
_wait_forever(server)
super().__init__(modyn_config, modyn_config["storage"]["port"], callback)

def __getstate__(self):
for variable_name, value in vars(self).items():
try:
pickle.dumps(value)
except:
print(f'{variable_name} with value {value} is not pickable')

class SelectorServer:
def __init__(self, modyn_config: dict) -> None:
self.modyn_config = modyn_config
self.selector_manager = SelectorManager(modyn_config)
self.sample_batch_size = self.modyn_config["selector"]["sample_batch_size"]
self.workers = []
state = self.__dict__.copy()
if "add_servicer_callback" in state:
del state["add_servicer_callback"]

def run(self) -> None:
port = self.modyn_config["selector"]["port"]
logger.info(f"Starting server. Listening on port {port}")
with _reserve_port(port) as port:
bind_address = "[::]:" + port
for _ in range(64):
worker = mp.Process(
target=_run_server,
args=(bind_address, self.selector_manager, self.sample_batch_size),
)
worker.start()
self.workers.append(worker)
return state

for worker in self.workers:
worker.join()
def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> None:
"""Exit the context manager.
Args:
exc_type (type): exception type
exc_val (Exception): exception value
exc_tb (Exception): exception traceback
"""
super().__exit__(exc_type, exc_val, exc_tb)
if (
"cleanup_trigger_samples_after_shutdown" in self.modyn_config["selector"]
and self.modyn_config["selector"]["cleanup_trigger_samples_after_shutdown"]
Expand Down
2 changes: 1 addition & 1 deletion modyn/selector/internal/selector_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _populate_pipeline_if_exists(self, pipeline_id: int) -> None:
return
logging.info(
f"[{os.getpid()}] Instantiating new selector for pipeline {pipeline_id}"
+ " that was in the DB but previously unknown to this process.."
+ " that was in the DB but previously unknown to this process"
)

self._instantiate_selector(pipeline_id, pipeline.num_workers, pipeline.selection_strategy)
Expand Down
2 changes: 1 addition & 1 deletion modyn/selector/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _populate_trigger_if_exists(self, trigger_id: int) -> None:
return

with MetadataDatabaseConnection(self._modyn_config) as database:
trigger: Optional[Trigger] = database.session.get(Trigger, trigger_id, self._pipeline_id)
trigger: Optional[Trigger] = database.session.get(Trigger, (trigger_id, self._pipeline_id))
if trigger is None:
return

Expand Down
18 changes: 14 additions & 4 deletions modyn/selector/selector_entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import argparse
import logging
import multiprocessing as mp
import os
import pathlib

import yaml
from modyn.selector.internal.grpc.selector_server import SelectorServer
from modyn.selector.internal.grpc.selector_server import SelectorGRPCServer

logging.basicConfig(
level=logging.NOTSET,
Expand All @@ -12,6 +14,14 @@
)
logger = logging.getLogger(__name__)

# We need to do this at the top because other dependencies otherwise set fork.
try:
    mp.set_start_method("spawn")
except RuntimeError:
    # The start method was already fixed elsewhere. That is only an error if it
    # is something other than "spawn" and we are not running under pytest.
    if mp.get_start_method() != "spawn" and "PYTEST_CURRENT_TEST" not in os.environ:
        # logging interpolates with %-style placeholders; a "{}" placeholder
        # combined with a positional argument makes the record fail to format.
        logger.error("Start method is already set to %s", mp.get_start_method())
        raise


def setup_argparser() -> argparse.ArgumentParser:
parser_ = argparse.ArgumentParser(description="Modyn Selector")
Expand All @@ -35,9 +45,9 @@ def main() -> None:
modyn_config = yaml.safe_load(config_file)

logger.info("Initializing selector server.")
selector = SelectorServer(modyn_config)
logger.info("Starting selector server.")
selector.run()

with SelectorGRPCServer(modyn_config):
pass

logger.info("Selector server returned, exiting.")

Expand Down
Loading

0 comments on commit 5b2c4fe

Please sign in to comment.