From 2b7369360265e8cb716dcb515cc84927634db830 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20B=C3=B6ther?= Date: Mon, 2 Oct 2023 15:59:14 +0200 Subject: [PATCH] refactor to generic grpc server --- modyn/common/grpc/__init__.py | 10 ++ modyn/common/grpc/grpc_helpers.py | 113 ++++++++++++++++++ .../internal/grpc/selector_grpc_servicer.py | 7 +- .../selector/internal/grpc/selector_server.py | 85 ++++--------- modyn/selector/internal/selector_manager.py | 2 +- modyn/selector/selector.py | 2 +- modyn/selector/selector_entrypoint.py | 18 ++- modyn/storage/internal/grpc/grpc_server.py | 91 ++------------ modyn/storage/storage.py | 4 +- .../internal/grpc/test_selector_server.py | 10 +- .../selector/test_selector_entrypoint.py | 10 +- .../storage/internal/grpc/test_grpc_server.py | 6 +- modyn/tests/storage/test_storage.py | 4 +- .../internal/dataset/online_dataset.py | 25 +++- 14 files changed, 218 insertions(+), 169 deletions(-) create mode 100644 modyn/common/grpc/__init__.py create mode 100644 modyn/common/grpc/grpc_helpers.py diff --git a/modyn/common/grpc/__init__.py b/modyn/common/grpc/__init__.py new file mode 100644 index 000000000..6040a0a16 --- /dev/null +++ b/modyn/common/grpc/__init__.py @@ -0,0 +1,10 @@ +""" +This submodule implements functions to run gRPC servers using multiprocessing. 
+""" +import os + +from .grpc_helpers import GenericGRPCServer # noqa: F401 + +files = os.listdir(os.path.dirname(__file__)) +files.remove("__init__.py") +__all__ = [f[:-3] for f in files if f.endswith(".py")] diff --git a/modyn/common/grpc/grpc_helpers.py b/modyn/common/grpc/grpc_helpers.py new file mode 100644 index 000000000..63ed2852a --- /dev/null +++ b/modyn/common/grpc/grpc_helpers.py @@ -0,0 +1,113 @@ +import contextlib +import datetime +import logging +import multiprocessing as mp +import os +import socket +import time +from concurrent import futures +from typing import Any, Callable + +import grpc +from modyn.utils import MAX_MESSAGE_SIZE + +logger = logging.getLogger(__name__) + +PROCESS_THREAD_WORKERS = 16 +NUM_GPRC_PROCESSES = 64 + + +@contextlib.contextmanager +def reserve_port(port: str): + """Find and reserve a port for all subprocesses to use.""" + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: + raise RuntimeError("Failed to set SO_REUSEPORT.") + sock.bind(("", int(port))) + try: + assert sock.getsockname()[1] == int(port) + yield port + finally: + sock.close() + + +def _wait_forever(server): + try: + while True: + time.sleep(datetime.timedelta(days=1).total_seconds()) + except KeyboardInterrupt: + server.stop(None) + + +def _run_server_worker(bind_address: str, add_servicer_callback: Callable, modyn_config: dict): + """Start a server in a subprocess.""" + logging.info(f"[{os.getpid()}] Starting new gRPC server process.") + + server = grpc.server( + futures.ThreadPoolExecutor( + max_workers=PROCESS_THREAD_WORKERS, + ), + options=[ + ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), + ("grpc.so_reuseport", 1), + ], + ) + + add_servicer_callback(modyn_config, server) + server.add_insecure_port(bind_address) + server.start() + _wait_forever(server) + + 
+class GenericGRPCServer: + def __init__(self, modyn_config: dict, port: str, add_servicer_callback: Callable) -> None: + """Initialize the GRPC server. + + Args: + TODO + """ + self.port = port + self.modyn_config = modyn_config + self.add_servicer_callback = add_servicer_callback + self.workers = [] + + def __enter__(self) -> Any: + """Enter the context manager. + + Returns: + grpc.Server: GRPC server + """ + logger.info(f"[{os.getpid()}] Starting server. Listening on port {self.port}") + with reserve_port(self.port) as port: + bind_address = "[::]:" + port + for _ in range(NUM_GPRC_PROCESSES): + worker = mp.Process( + target=_run_server_worker, + args=(bind_address, self.add_servicer_callback, self.modyn_config), + ) + worker.start() + self.workers.append(worker) + + return self + + def __getstate__(self): + state = self.__dict__.copy() + del state["add_servicer_callback"] + return state + + def wait_for_termination(self) -> None: + for worker in self.workers: + worker.join() + + def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> None: + """Exit the context manager. 
+ + Args: + exc_type (type): exception type + exc_val (Exception): exception value + exc_tb (Exception): exception traceback + """ + self.wait_for_termination() + del self.workers diff --git a/modyn/selector/internal/grpc/selector_grpc_servicer.py b/modyn/selector/internal/grpc/selector_grpc_servicer.py index a1bea06ec..0db6cf8f6 100644 --- a/modyn/selector/internal/grpc/selector_grpc_servicer.py +++ b/modyn/selector/internal/grpc/selector_grpc_servicer.py @@ -1,5 +1,7 @@ import json import logging +import os +import threading from typing import Iterable import grpc @@ -59,8 +61,11 @@ def get_sample_keys_and_weights( # pylint: disable-next=unused-argument request.worker_id, request.partition_id, ) + tid = threading.get_native_id() + pid = os.getpid() + logger.info( - f"[Pipeline {pipeline_id}]: Fetching samples for trigger id {trigger_id}" + f"[{pid}][{tid}][Pipeline {pipeline_id}]: Fetching samples for trigger id {trigger_id}" + f" and worker id {worker_id} and partition id {partition_id}" ) diff --git a/modyn/selector/internal/grpc/selector_server.py b/modyn/selector/internal/grpc/selector_server.py index ead69a692..2923ef47d 100644 --- a/modyn/selector/internal/grpc/selector_server.py +++ b/modyn/selector/internal/grpc/selector_server.py @@ -1,5 +1,6 @@ import contextlib import datetime +import functools import logging import multiprocessing as mp import os @@ -8,6 +9,7 @@ from concurrent import futures import grpc +from modyn.common.grpc import GenericGRPCServer from modyn.selector.internal.grpc.generated.selector_pb2_grpc import add_SelectorServicer_to_server # noqa: E402, E501 from modyn.selector.internal.grpc.selector_grpc_servicer import SelectorGRPCServicer from modyn.selector.internal.selector_manager import SelectorManager @@ -16,72 +18,37 @@ logger = logging.getLogger(__name__) -@contextlib.contextmanager -def _reserve_port(port: str): - """Find and reserve a port for all subprocesses to use.""" - sock = socket.socket(socket.AF_INET6, 
socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: - raise RuntimeError("Failed to set SO_REUSEPORT.") - sock.bind(("", int(port))) - try: - assert sock.getsockname()[1] == int(port) - yield port - finally: - sock.close() +class SelectorGRPCServer(GenericGRPCServer): + @staticmethod + def callback(selector_manager, modyn_config, server): + add_SelectorServicer_to_server( + SelectorGRPCServicer(selector_manager, modyn_config["selector"]["sample_batch_size"]), server + ) - -def _wait_forever(server): - try: - while True: - time.sleep(datetime.timedelta(days=1).total_seconds()) - except KeyboardInterrupt: - server.stop(None) - - -def _run_server(bind_address, selector_manager, sample_batch_size): - """Start a server in a subprocess.""" - logging.info(f"[{os.getpid()}] Starting new server.") - - server = grpc.server( - futures.ThreadPoolExecutor( - max_workers=16, - ), - options=[ - ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), - ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), - ("grpc.so_reuseport", 1), - ], - ) - add_SelectorServicer_to_server(SelectorGRPCServicer(selector_manager, sample_batch_size), server) - server.add_insecure_port(bind_address) - server.start() - _wait_forever(server) - - -class SelectorServer: def __init__(self, modyn_config: dict) -> None: self.modyn_config = modyn_config self.selector_manager = SelectorManager(modyn_config) - self.sample_batch_size = self.modyn_config["selector"]["sample_batch_size"] - self.workers = [] - def run(self) -> None: - port = self.modyn_config["selector"]["port"] - logger.info(f"Starting server. 
Listening on port {port}") - with _reserve_port(port) as port: - bind_address = "[::]:" + port - for _ in range(64): - worker = mp.Process( - target=_run_server, - args=(bind_address, self.selector_manager, self.sample_batch_size), - ) - worker.start() - self.workers.append(worker) + callback = functools.partial(SelectorGRPCServer.callback, self.selector_manager) + + super().__init__(modyn_config, modyn_config["selector"]["port"], callback) + + def __getstate__(self): + state = self.__dict__.copy() + if "add_servicer_callback" in state: + del state["add_servicer_callback"] + + return state - for worker in self.workers: - worker.join() + def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> None: + """Exit the context manager. + Args: + exc_type (type): exception type + exc_val (Exception): exception value + exc_tb (Exception): exception traceback + """ + super().__exit__(exc_type, exc_val, exc_tb) if ( "cleanup_trigger_samples_after_shutdown" in self.modyn_config["selector"] and self.modyn_config["selector"]["cleanup_trigger_samples_after_shutdown"] diff --git a/modyn/selector/internal/selector_manager.py b/modyn/selector/internal/selector_manager.py index ce271fec7..5994eda5b 100644 --- a/modyn/selector/internal/selector_manager.py +++ b/modyn/selector/internal/selector_manager.py @@ -70,7 +70,7 @@ def _populate_pipeline_if_exists(self, pipeline_id: int) -> None: return logging.info( f"[{os.getpid()}] Instantiating new selector for pipeline {pipeline_id}" - + " that was in the DB but previously unknown to this process.." 
+ + " that was in the DB but previously unknown to this process" ) self._instantiate_selector(pipeline_id, pipeline.num_workers, pipeline.selection_strategy) diff --git a/modyn/selector/selector.py b/modyn/selector/selector.py index b894b0626..3050bf3ae 100644 --- a/modyn/selector/selector.py +++ b/modyn/selector/selector.py @@ -40,7 +40,7 @@ def _populate_trigger_if_exists(self, trigger_id: int) -> None: return with MetadataDatabaseConnection(self._modyn_config) as database: - trigger: Optional[Trigger] = database.session.get(Trigger, trigger_id, self._pipeline_id) + trigger: Optional[Trigger] = database.session.get(Trigger, (trigger_id, self._pipeline_id)) if trigger is None: return diff --git a/modyn/selector/selector_entrypoint.py b/modyn/selector/selector_entrypoint.py index 152b4c125..0795819c1 100644 --- a/modyn/selector/selector_entrypoint.py +++ b/modyn/selector/selector_entrypoint.py @@ -1,9 +1,11 @@ import argparse import logging +import multiprocessing as mp +import os import pathlib import yaml -from modyn.selector.internal.grpc.selector_server import SelectorServer +from modyn.selector.internal.grpc.selector_server import SelectorGRPCServer logging.basicConfig( level=logging.NOTSET, @@ -12,6 +14,14 @@ ) logger = logging.getLogger(__name__) +# We need to do this at the top because other dependencies otherwise set fork. 
+try: + mp.set_start_method("spawn") +except RuntimeError as error: + if mp.get_start_method() != "spawn" and "PYTEST_CURRENT_TEST" not in os.environ: + logger.error("Start method is already set to %s", mp.get_start_method()) + raise error + + def setup_argparser() -> argparse.ArgumentParser: parser_ = argparse.ArgumentParser(description="Modyn Selector") @@ -35,9 +45,9 @@ def main() -> None: modyn_config = yaml.safe_load(config_file) logger.info("Initializing selector server.") - selector = SelectorServer(modyn_config) - logger.info("Starting selector server.") - selector.run() + + with SelectorGRPCServer(modyn_config): + pass logger.info("Selector server returned, exiting.") diff --git a/modyn/storage/internal/grpc/grpc_server.py b/modyn/storage/internal/grpc/grpc_server.py index 7f14520a3..5aa596984 100644 --- a/modyn/storage/internal/grpc/grpc_server.py +++ b/modyn/storage/internal/grpc/grpc_server.py @@ -11,6 +11,7 @@ from typing import Any import grpc +from modyn.common.grpc import GenericGRPCServer from modyn.storage.internal.grpc.generated.storage_pb2_grpc import add_StorageServicer_to_server from modyn.storage.internal.grpc.storage_grpc_servicer import StorageGRPCServicer from modyn.utils import MAX_MESSAGE_SIZE @@ -18,95 +19,17 @@ logger = logging.getLogger(__name__) -@contextlib.contextmanager -def _reserve_port(port: str): - """Find and reserve a port for all subprocesses to use.""" - sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: - raise RuntimeError("Failed to set SO_REUSEPORT.") - sock.bind(("", int(port))) - try: - assert sock.getsockname()[1] == int(port) - yield port - finally: - sock.close() - - -def _wait_forever(server): - try: - while True: - time.sleep(datetime.timedelta(days=1).total_seconds()) - except KeyboardInterrupt: - server.stop(None) - - -def _run_server(bind_address, modyn_config): - """Start a 
server in a subprocess.""" - logging.info(f"[{os.getpid()}] Starting new server.") - - server = grpc.server( - futures.ThreadPoolExecutor( - max_workers=16, - ), - options=[ - ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), - ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), - ("grpc.so_reuseport", 1), - ], - ) - add_StorageServicer_to_server(StorageGRPCServicer(modyn_config), server) - server.add_insecure_port(bind_address) - server.start() - _wait_forever(server) - - -class GRPCServer: +class StorageGRPCServer(GenericGRPCServer): """GRPC server context manager.""" + @staticmethod + def callback(modyn_config, server): + add_StorageServicer_to_server(StorageGRPCServicer(modyn_config), server) + def __init__(self, modyn_config: dict) -> None: """Initialize the GRPC server. Args: modyn_config (dict): Configuration of the storage module. """ - self.modyn_config = modyn_config - self.workers = [] - - def __enter__(self) -> Any: - """Enter the context manager. - - Returns: - grpc.Server: GRPC server - """ - port = self.modyn_config["storage"]["port"] - logger.info(f"Starting server. Listening on port {port}") - with _reserve_port(port) as port: - bind_address = "[::]:" + port - for _ in range(64): - worker = mp.Process( - target=_run_server, - args=( - bind_address, - self.modyn_config, - ), - ) - worker.start() - self.workers.append(worker) - - return self - - def wait_for_termination(self) -> None: - for worker in self.workers: - worker.join() - - def __exit__(self, exc_type: type, exc_val: Exception, exc_tb: Exception) -> None: - """Exit the context manager. 
- - Args: - exc_type (type): exception type - exc_val (Exception): exception value - exc_tb (Exception): exception traceback - """ - self.wait_for_termination() - del self.workers + super().__init__(modyn_config, modyn_config["storage"]["port"], StorageGRPCServer.callback) diff --git a/modyn/storage/storage.py b/modyn/storage/storage.py index 17cba3b48..c2e8e3176 100644 --- a/modyn/storage/storage.py +++ b/modyn/storage/storage.py @@ -14,7 +14,7 @@ from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection from modyn.storage.internal.file_watcher.new_file_watcher_watch_dog import run_watcher_watch_dog -from modyn.storage.internal.grpc.grpc_server import GRPCServer +from modyn.storage.internal.grpc.grpc_server import StorageGRPCServer from modyn.utils import validate_yaml logger = logging.getLogger(__name__) @@ -77,7 +77,7 @@ def run(self) -> None: watchdog.start() #  Start the storage grpc server. - with GRPCServer(self.modyn_config) as server: + with StorageGRPCServer(self.modyn_config) as server: server.wait_for_termination() should_stop.value = True # type: ignore # See https://github.com/python/typeshed/issues/8799 diff --git a/modyn/tests/selector/internal/grpc/test_selector_server.py b/modyn/tests/selector/internal/grpc/test_selector_server.py index 47c3e73e2..73e416ba1 100644 --- a/modyn/tests/selector/internal/grpc/test_selector_server.py +++ b/modyn/tests/selector/internal/grpc/test_selector_server.py @@ -3,7 +3,7 @@ from unittest import mock from unittest.mock import MagicMock, patch -from modyn.selector.internal.grpc.selector_server import SelectorServer +from modyn.selector.internal.grpc.selector_server import SelectorGRPCServer from modyn.selector.internal.selector_manager import SelectorManager @@ -27,7 +27,7 @@ def test_init(): with tempfile.TemporaryDirectory() as tmp_dir: config = get_modyn_config() config["selector"]["trigger_sample_directory"] = tmp_dir - grpc_server = SelectorServer(config) + grpc_server 
= SelectorGRPCServer(config) assert grpc_server.modyn_config == config @@ -36,7 +36,7 @@ def test_prepare_server(): with tempfile.TemporaryDirectory() as tmp_dir: config = get_modyn_config() config["selector"]["trigger_sample_directory"] = tmp_dir - grpc_server = SelectorServer(config) + grpc_server = SelectorGRPCServer(config) mock_add = mock.Mock() grpc_server._add_servicer_to_server_func = mock_add @@ -46,12 +46,12 @@ def test_prepare_server(): @patch.object(SelectorManager, "init_metadata_db", noop_init_metadata_db) -@patch.object(SelectorServer, "prepare_server") +@patch.object(SelectorGRPCServer, "prepare_server") def test_run(test_prepare_server: MagicMock): with tempfile.TemporaryDirectory() as tmp_dir: config = get_modyn_config() config["selector"]["trigger_sample_directory"] = tmp_dir - grpc_server = SelectorServer(config) + grpc_server = SelectorGRPCServer(config) mock_start = mock.Mock() mock_wait = mock.Mock() diff --git a/modyn/tests/selector/test_selector_entrypoint.py b/modyn/tests/selector/test_selector_entrypoint.py index 1b7083efe..33f6d46b7 100644 --- a/modyn/tests/selector/test_selector_entrypoint.py +++ b/modyn/tests/selector/test_selector_entrypoint.py @@ -6,7 +6,7 @@ import pathlib from unittest.mock import patch -from modyn.selector.internal.grpc.selector_server import SelectorServer +from modyn.selector.internal.grpc.selector_server import SelectorGRPCServer SCRIPT_PATH = pathlib.Path(os.path.realpath(__file__)) @@ -23,15 +23,15 @@ def noop_run(self) -> None: pass -@patch.object(SelectorServer, "__init__", noop_constructor_mock) -@patch.object(SelectorServer, "run", noop_run) +@patch.object(SelectorGRPCServer, "__init__", noop_constructor_mock) +@patch.object(SelectorGRPCServer, "run", noop_run) def test_trainer_server_script_runs(script_runner): ret = script_runner.run("_modyn_selector", str(EXAMPLE_SYSTEM_CONFIG)) assert ret.success -@patch.object(SelectorServer, "__init__", noop_constructor_mock) -@patch.object(SelectorServer, "run", 
noop_run) +@patch.object(SelectorGRPCServer, "__init__", noop_constructor_mock) +@patch.object(SelectorGRPCServer, "run", noop_run) def test_trainer_server_fails_on_non_existing_system_config(script_runner): ret = script_runner.run("_modyn_selector", str(NO_FILE)) assert not ret.success diff --git a/modyn/tests/storage/internal/grpc/test_grpc_server.py b/modyn/tests/storage/internal/grpc/test_grpc_server.py index 5f7795d11..b5f3817f1 100644 --- a/modyn/tests/storage/internal/grpc/test_grpc_server.py +++ b/modyn/tests/storage/internal/grpc/test_grpc_server.py @@ -1,7 +1,7 @@ # pylint: disable=unused-argument from unittest.mock import patch -from modyn.storage.internal.grpc.grpc_server import GRPCServer +from modyn.storage.internal.grpc.grpc_server import StorageGRPCServer def get_modyn_config(): @@ -9,11 +9,11 @@ def get_modyn_config(): def test_init(): - grpc_server = GRPCServer(get_modyn_config()) + grpc_server = StorageGRPCServer(get_modyn_config()) assert grpc_server.modyn_config == get_modyn_config() @patch("modyn.storage.internal.grpc.grpc_server.add_StorageServicer_to_server", return_value=None) def test_enter(mock_add_storage_servicer_to_server): - with GRPCServer(get_modyn_config()) as grpc_server: + with StorageGRPCServer(get_modyn_config()) as grpc_server: assert grpc_server is not None diff --git a/modyn/tests/storage/test_storage.py b/modyn/tests/storage/test_storage.py index 5ba24caa8..e0b1c6806 100644 --- a/modyn/tests/storage/test_storage.py +++ b/modyn/tests/storage/test_storage.py @@ -4,7 +4,7 @@ import pytest from modyn.storage.internal.database.storage_database_connection import StorageDatabaseConnection -from modyn.storage.internal.grpc.grpc_server import GRPCServer +from modyn.storage.internal.grpc.grpc_server import StorageGRPCServer from modyn.storage.storage import Storage database_path = pathlib.Path(os.path.abspath(__file__)).parent / "test_storage.db" @@ -76,7 +76,7 @@ def wait_for_termination(self, *args, **kwargs): # pylint: 
disable=unused-argum return -class MockGRPCServer(GRPCServer): +class MockGRPCServer(StorageGRPCServer): def __enter__(self): return MockGRPCInstance() diff --git a/modyn/trainer_server/internal/dataset/online_dataset.py b/modyn/trainer_server/internal/dataset/online_dataset.py index 7b07c11eb..48c9f8334 100644 --- a/modyn/trainer_server/internal/dataset/online_dataset.py +++ b/modyn/trainer_server/internal/dataset/online_dataset.py @@ -1,4 +1,5 @@ import contextlib +import functools import json import logging import os @@ -154,6 +155,7 @@ def _get_data( partition_valid_until: Optional[dict], partition_locks: Optional[dict], partition_signals: Optional[dict], + callback: Optional[Callable] = None, ) -> None: get_data_log = {} self._sw.start(f"GetKeysAndWeightsPart{partition_id}", overwrite=True) @@ -195,6 +197,9 @@ def _get_data( with partition_locks[partition_id]: partition_valid[partition_id] = True + if callback is not None: + callback() + def _get_transformed_data_tuple( self, key: int, sample: bytes, label: int, weight: Optional[float] ) -> Optional[Tuple]: @@ -226,7 +231,7 @@ def _persist_log(self, worker_id: int) -> None: with open(log_file, "w", encoding="utf-8") as logfile: json.dump(self._log, logfile) - def _prefetch_partition(self, worker_id: int) -> None: + def _prefetch_partition(self, worker_id: int, num_additional_prefetches: int = 0) -> None: if self._prefetched_partitions < 1 or self._next_partition_to_fetch >= self._num_partitions: return # Prefetching disabled or nothing more to prefetch @@ -248,6 +253,12 @@ def _prefetch_partition(self, worker_id: int) -> None: self._partition_locks[self._next_partition_to_fetch] ) + def potential_callback(): + self._info("Prefetch callback called.") + self._prefetch_partition(worker_id, num_additional_prefetches - 1) + + callback = None if num_additional_prefetches == 0 else potential_callback + self._data_threads[self._next_partition_to_fetch] = threading.Thread( target=self._get_data, args=( @@ -258,6 +269,7 @@ def 
_prefetch_partition(self, worker_id: int) -> None: self._partition_valid_until, self._partition_locks, self._partition_signals, + callback, ), ) @@ -278,6 +290,9 @@ def _fetch_partition_noprefetch( yield container["keys"][idx], container["data"][idx], container["labels"][idx], container["weights"][idx] def _is_partition_fetched(self, partition_id: int) -> bool: + if partition_id not in self._partition_locks or partition_id not in self._partition_valid: + return False + with self._partition_locks[partition_id]: return self._partition_valid[partition_id] @@ -304,7 +319,6 @@ def _wait_for_new_partition_data(self, partition_id: int) -> None: def prefetched_partition_generator( self, worker_id: int, partition_id: int ) -> Iterator[tuple[int, bytes, int, Optional[float]]]: - assert self._pref_started[partition_id], f"Prefetching for partition {partition_id} has not been started" last_idx = -1 while not self._is_partition_fetched(partition_id): @@ -322,7 +336,14 @@ def prefetched_partition_generator( max_idx = self._partition_max_index(partition_id) yield from self._get_partition_data(last_idx, max_idx, partition_id) + def start_prefetching(self, worker_id: int) -> None: + if self._prefetched_partitions < 1: + return + + self._prefetch_partition(worker_id, self._prefetched_partitions - 1) + def all_partition_generator(self, worker_id: int) -> Iterator[tuple[int, bytes, int, Optional[float]]]: + self.start_prefetching(worker_id) for _ in range(self._prefetched_partitions): self._prefetch_partition(worker_id)