diff --git a/modyn/metadata_database/metadata_database_connection.py b/modyn/metadata_database/metadata_database_connection.py index ac337bc1f..311a8c3f0 100644 --- a/modyn/metadata_database/metadata_database_connection.py +++ b/modyn/metadata_database/metadata_database_connection.py @@ -67,16 +67,17 @@ def create_tables(self) -> None: """ MetadataBase.metadata.create_all(self.engine) - def register_pipeline(self, num_workers: int) -> int: + def register_pipeline(self, num_workers: int, selection_strategy: str) -> int: """Register a new pipeline in the database. Args: num_workers (int): Number of workers in the pipeline. + selection_strategy (str): The selection strategy to use Returns: int: Id of the newly created pipeline. """ - pipeline = Pipeline(num_workers=num_workers) + pipeline = Pipeline(num_workers=num_workers, selection_strategy=selection_strategy) self.session.add(pipeline) self.session.commit() pipeline_id = pipeline.pipeline_id diff --git a/modyn/metadata_database/models/pipelines.py b/modyn/metadata_database/models/pipelines.py index 4094b3f95..cd8370c7e 100644 --- a/modyn/metadata_database/models/pipelines.py +++ b/modyn/metadata_database/models/pipelines.py @@ -1,7 +1,7 @@ """Pipeline model.""" from modyn.metadata_database.metadata_base import MetadataBase -from sqlalchemy import Column, Integer +from sqlalchemy import Column, Integer, Text class Pipeline(MetadataBase): @@ -12,6 +12,7 @@ class Pipeline(MetadataBase): __table_args__ = {"extend_existing": True} pipeline_id = Column("pipeline_id", Integer, primary_key=True) num_workers = Column("num_workers", Integer, nullable=False) + selection_strategy = Column("selection_strategy", Text, nullable=False) def __repr__(self) -> str: """Return string representation.""" diff --git a/modyn/selector/internal/grpc/selector_server.py b/modyn/selector/internal/grpc/selector_server.py index 3ca0fc4d9..ead69a692 100644 --- a/modyn/selector/internal/grpc/selector_server.py +++ b/modyn/selector/internal/grpc/selector_server.py @@ -1,4 +1,10 @@ +import contextlib +import datetime import logging +import multiprocessing as mp +import os +import socket +import time from concurrent import futures import grpc @@ -10,32 +16,72 @@ logger = logging.getLogger(__name__) +@contextlib.contextmanager +def _reserve_port(port: str): + """Find and reserve a port for all subprocesses to use.""" + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: + raise RuntimeError("Failed to set SO_REUSEPORT.") + sock.bind(("", int(port))) + try: + assert sock.getsockname()[1] == int(port) + yield port + finally: + sock.close() + + +def _wait_forever(server): + try: + while True: + time.sleep(datetime.timedelta(days=1).total_seconds()) + except KeyboardInterrupt: + server.stop(None) + + +def _run_server(bind_address, selector_manager, sample_batch_size): + """Start a server in a subprocess.""" + logging.info(f"[{os.getpid()}] Starting new server.") + + server = grpc.server( + futures.ThreadPoolExecutor( + max_workers=16, + ), + options=[ + ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), + ("grpc.so_reuseport", 1), + ], + ) + add_SelectorServicer_to_server(SelectorGRPCServicer(selector_manager, sample_batch_size), server) + server.add_insecure_port(bind_address) + server.start() + _wait_forever(server) + + class SelectorServer: def __init__(self, modyn_config: dict) -> None: self.modyn_config = modyn_config self.selector_manager = SelectorManager(modyn_config) - self.grpc_servicer = SelectorGRPCServicer( - self.selector_manager, self.modyn_config["selector"]["sample_batch_size"] - ) - self._add_servicer_to_server_func = add_SelectorServicer_to_server - - def prepare_server(self) -> grpc.server: - server = grpc.server( - futures.ThreadPoolExecutor(max_workers=64), - options=[ - ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), - ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), - ], - ) - self._add_servicer_to_server_func(self.grpc_servicer, server) - return server + self.sample_batch_size = self.modyn_config["selector"]["sample_batch_size"] + self.workers = [] def run(self) -> None: - server = self.prepare_server() - logger.info(f"Starting server. Listening on port {self.modyn_config['selector']['port']}.") - server.add_insecure_port("[::]:" + self.modyn_config["selector"]["port"]) - server.start() - server.wait_for_termination() + port = self.modyn_config["selector"]["port"] + logger.info(f"Starting server. Listening on port {port}") + with _reserve_port(port) as port: + bind_address = "[::]:" + port + for _ in range(64): + worker = mp.Process( + target=_run_server, + args=(bind_address, self.selector_manager, self.sample_batch_size), + ) + worker.start() + self.workers.append(worker) + + for worker in self.workers: + worker.join() + if ( "cleanup_trigger_samples_after_shutdown" in self.modyn_config["selector"] and self.modyn_config["selector"]["cleanup_trigger_samples_after_shutdown"] diff --git a/modyn/selector/internal/selector_manager.py b/modyn/selector/internal/selector_manager.py index 51fa6bf89..1e2b56571 100644 --- a/modyn/selector/internal/selector_manager.py +++ b/modyn/selector/internal/selector_manager.py @@ -4,10 +4,12 @@ import logging import os import shutil +from multiprocessing import Lock, Manager from pathlib import Path -from threading import Lock +from typing import Optional from modyn.metadata_database.metadata_database_connection import MetadataDatabaseConnection +from modyn.metadata_database.models.pipelines import Pipeline from modyn.selector.internal.selector_strategies.abstract_selection_strategy import AbstractSelectionStrategy from modyn.selector.selector import Selector from modyn.utils.utils import dynamic_module_import, is_directory_writable @@ -18,9 +20,10 @@ class SelectorManager: def __init__(self, modyn_config: dict) -> None: self._modyn_config = modyn_config + self._manager = Manager() self._selectors: dict[int, Selector] = {} - self._selector_locks: dict[int, Lock] = {} - self._next_pipeline_lock = Lock() + self._selector_locks: dict[int, Lock] = self._manager.dict() + self._next_pipeline_lock = self._manager.Lock() self._selector_cache_size = self._modyn_config["selector"]["keys_in_selector_cache"] self.init_metadata_db() @@ -57,6 +60,27 @@ def _init_trigger_sample_directory(self) -> None: + f"Directory info: {os.stat(trigger_sample_directory)}" ) + def _populate_pipeline_if_exists(self, pipeline_id: int) -> None: + if pipeline_id in self._selectors: + return + + with MetadataDatabaseConnection(self._modyn_config) as database: + pipeline: Optional[Pipeline] = database.session.get(Pipeline, pipeline_id) + if pipeline is None: + return + logging.info( + f"[{os.getpid()}] Instantiating new selector for pipeline {pipeline_id}" + + " that was in the DB but previously unknown to this process.." + ) + + self._instantiate_selector(pipeline_id, pipeline.num_workers, pipeline.selection_strategy) + + def _instantiate_selector(self, pipeline_id: int, num_workers: int, selection_strategy: str) -> None: + assert pipeline_id in self._selector_locks, f"Trying to register pipeline {pipeline_id} without existing lock!" + selection_strategy = self._instantiate_strategy(json.loads(selection_strategy), pipeline_id) + selector = Selector(selection_strategy, pipeline_id, num_workers, self._selector_cache_size) + self._selectors[pipeline_id] = selector + def register_pipeline(self, num_workers: int, selection_strategy: str) -> int: """ Registers a new pipeline at the Selector. @@ -70,12 +94,11 @@ def register_pipeline(self, num_workers: int, selection_strategy: str) -> int: with self._next_pipeline_lock: with MetadataDatabaseConnection(self._modyn_config) as database: - pipeline_id = database.register_pipeline(num_workers) + pipeline_id = database.register_pipeline(num_workers, selection_strategy) + + self._selector_locks[pipeline_id] = self._manager.Lock() + self._instantiate_selector(pipeline_id, num_workers, selection_strategy) - selection_strategy = self._instantiate_strategy(json.loads(selection_strategy), pipeline_id) - selector = Selector(selection_strategy, pipeline_id, num_workers, self._selector_cache_size) - self._selectors[pipeline_id] = selector - self._selector_locks[pipeline_id] = Lock() return pipeline_id def get_sample_keys_and_weights( @@ -92,6 +115,8 @@ def get_sample_keys_and_weights( List of tuples for the samples to be returned to that particular worker. The first index of the tuple will be the key, and the second index will be that sample's weight. """ + self._populate_pipeline_if_exists(pipeline_id) + if pipeline_id not in self._selectors: raise ValueError(f"Requested keys from pipeline {pipeline_id} which does not exist!") @@ -104,6 +129,8 @@ def get_sample_keys_and_weights( def inform_data( self, pipeline_id: int, keys: list[int], timestamps: list[int], labels: list[int] ) -> dict[str, object]: + self._populate_pipeline_if_exists(pipeline_id) + if pipeline_id not in self._selectors: raise ValueError(f"Informing pipeline {pipeline_id} of data. Pipeline does not exist!") @@ -113,6 +140,8 @@ def inform_data( def inform_data_and_trigger( self, pipeline_id: int, keys: list[int], timestamps: list[int], labels: list[int] ) -> tuple[int, dict[str, object]]: + self._populate_pipeline_if_exists(pipeline_id) + if pipeline_id not in self._selectors: raise ValueError(f"Informing pipeline {pipeline_id} of data and triggering. Pipeline does not exist!") @@ -120,30 +149,40 @@ def inform_data_and_trigger( return self._selectors[pipeline_id].inform_data_and_trigger(keys, timestamps, labels) def get_number_of_samples(self, pipeline_id: int, trigger_id: int) -> int: + self._populate_pipeline_if_exists(pipeline_id) + if pipeline_id not in self._selectors: raise ValueError(f"Requested number of samples from pipeline {pipeline_id} which does not exist!") return self._selectors[pipeline_id].get_number_of_samples(trigger_id) def get_status_bar_scale(self, pipeline_id: int) -> int: + self._populate_pipeline_if_exists(pipeline_id) + if pipeline_id not in self._selectors: raise ValueError(f"Requested status bar scale from pipeline {pipeline_id} which does not exist!") return self._selectors[pipeline_id].get_status_bar_scale() def get_number_of_partitions(self, pipeline_id: int, trigger_id: int) -> int: + self._populate_pipeline_if_exists(pipeline_id) + if pipeline_id not in self._selectors: raise ValueError(f"Requested number of partitions from pipeline {pipeline_id} which does not exist!") return self._selectors[pipeline_id].get_number_of_partitions(trigger_id) def get_available_labels(self, pipeline_id: int) -> list[int]: + self._populate_pipeline_if_exists(pipeline_id) + if pipeline_id not in self._selectors: raise ValueError(f"Requested available labels from pipeline {pipeline_id} which does not exist!") return self._selectors[pipeline_id].get_available_labels() def uses_weights(self, pipeline_id: int) -> bool: + self._populate_pipeline_if_exists(pipeline_id) + if pipeline_id not in self._selectors: raise ValueError(f"Requested whether the pipeline {pipeline_id} uses weights but it does not exist!") @@ -169,6 +208,8 @@ def _instantiate_strategy(self, selection_strategy: dict, pipeline_id: int) -> A return strategy_handler(config, self._modyn_config, pipeline_id, maximum_keys_in_memory) def get_selection_strategy_remote(self, pipeline_id: int) -> tuple[bool, str, dict]: + self._populate_pipeline_if_exists(pipeline_id) + if pipeline_id not in self._selectors: raise ValueError(f"Requested selection strategy for pipeline {pipeline_id} which does not exist!") diff --git a/modyn/storage/internal/grpc/grpc_server.py b/modyn/storage/internal/grpc/grpc_server.py index 0a76d6652..7f14520a3 100644 --- a/modyn/storage/internal/grpc/grpc_server.py +++ b/modyn/storage/internal/grpc/grpc_server.py @@ -1,7 +1,14 @@ """GRPC server context manager.""" +import contextlib +import datetime import logging +import multiprocessing as mp +import os +import socket +import time from concurrent import futures +from typing import Any import grpc from modyn.storage.internal.grpc.generated.storage_pb2_grpc import add_StorageServicer_to_server @@ -11,6 +18,49 @@ logger = logging.getLogger(__name__) +@contextlib.contextmanager +def _reserve_port(port: str): + """Find and reserve a port for all subprocesses to use.""" + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: + raise RuntimeError("Failed to set SO_REUSEPORT.") + sock.bind(("", int(port))) + try: + assert sock.getsockname()[1] == int(port) + yield port + finally: + sock.close() + + +def _wait_forever(server): + try: + while True: + time.sleep(datetime.timedelta(days=1).total_seconds()) + except KeyboardInterrupt: + server.stop(None) + + +def _run_server(bind_address, modyn_config): + """Start a server in a subprocess.""" + logging.info(f"[{os.getpid()}] Starting new server.") + + server = grpc.server( + futures.ThreadPoolExecutor( + max_workers=16, + ), + options=[ + ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), + ("grpc.so_reuseport", 1), + ], + ) + add_StorageServicer_to_server(StorageGRPCServicer(modyn_config), server) + server.add_insecure_port(bind_address) + server.start() + _wait_forever(server) + + class GRPCServer: """GRPC server context manager.""" @@ -21,28 +71,34 @@ def __init__(self, modyn_config: dict) -> None: modyn_config (dict): Configuration of the storage module. """ self.modyn_config = modyn_config - self.server = grpc.server( - futures.ThreadPoolExecutor( - max_workers=64, - ), - options=[ - ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), - ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), - ], - ) - - def __enter__(self) -> grpc.Server: + self.workers = [] + + def __enter__(self) -> Any: """Enter the context manager. Returns: grpc.Server: GRPC server """ - add_StorageServicer_to_server(StorageGRPCServicer(self.modyn_config), self.server) port = self.modyn_config["storage"]["port"] logger.info(f"Starting server. Listening on port {port}") - self.server.add_insecure_port("[::]:" + port) - self.server.start() - return self.server + with _reserve_port(port) as port: + bind_address = "[::]:" + port + for _ in range(64): + worker = mp.Process( + target=_run_server, + args=( + bind_address, + self.modyn_config, + ), + ) + worker.start() + self.workers.append(worker) + + return self + + def wait_for_termination(self) -> None: + for worker in self.workers: + worker.join() def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> None: """Exit the context manager. @@ -52,4 +108,5 @@ def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> Non exc_val (Exception): exception value exc_tb (Exception): exception traceback """ - self.server.stop(0) + self.wait_for_termination() + del self.workers diff --git a/modyn/storage/internal/grpc/storage_grpc_servicer.py b/modyn/storage/internal/grpc/storage_grpc_servicer.py index 219eb5c65..f3c8c8936 100644 --- a/modyn/storage/internal/grpc/storage_grpc_servicer.py +++ b/modyn/storage/internal/grpc/storage_grpc_servicer.py @@ -1,9 +1,12 @@ """Storage GRPC servicer.""" import logging +import os +import threading from typing import Iterable, Tuple import grpc +from modyn.common.benchmark.stopwatch import Stopwatch from modyn.storage.internal.database.models import Dataset, File, Sample from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection from modyn.storage.internal.database.storage_database_utils import get_file_wrapper, get_filesystem_wrapper @@ -64,6 +67,9 @@ def Get(self, request: GetRequest, context: grpc.ServicerContext) -> Iterable[Ge Yields: Iterator[Iterable[GetResponse]]: Response containing the data for the given keys. """ + tid = threading.get_native_id() + pid = os.getpid() + logger.info(f"[{pid}][{tid}] Received request for {len(request.keys)} items.") with StorageDatabaseConnection(self.modyn_config) as database: session = database.session @@ -73,12 +79,16 @@ def Get(self, request: GetRequest, context: grpc.ServicerContext) -> Iterable[Ge yield GetResponse() return + stopw = Stopwatch() + stopw.start("GetSamples") samples: list[Sample] = ( session.query(Sample) .filter(and_(Sample.sample_id.in_(request.keys), Sample.dataset_id == dataset.dataset_id)) .order_by(Sample.file_id) .all() ) + samples_time = stopw.stop() + logger.info(f"[{pid}][{tid}] Getting samples took {samples_time / 1000}s.") if len(samples) == 0: logger.error("No samples found in the database.") diff --git a/modyn/tests/model_storage/internal/grpc/test_model_storage.database b/modyn/tests/model_storage/internal/grpc/test_model_storage.database new file mode 100644 index 000000000..0902c438a Binary files /dev/null and b/modyn/tests/model_storage/internal/grpc/test_model_storage.database differ