From ad33db6ed8b94f4773429a306bce4b0b047d9aec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20B=C3=B6ther?= Date: Wed, 10 May 2023 22:57:47 +0200 Subject: [PATCH 1/2] implement async prefetch --- .../internal/dataset/online_dataset.py | 48 +++++++++++++++---- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/modyn/trainer_server/internal/dataset/online_dataset.py b/modyn/trainer_server/internal/dataset/online_dataset.py index 6b3551681..fd40e33b5 100644 --- a/modyn/trainer_server/internal/dataset/online_dataset.py +++ b/modyn/trainer_server/internal/dataset/online_dataset.py @@ -1,5 +1,6 @@ import gc import logging +import threading from inspect import isfunction from typing import Any, Callable, Generator, Optional @@ -138,11 +139,41 @@ def _info(self, msg: str, worker_id: Optional[int]) -> None: # pragma: no cover def _debug(self, msg: str, worker_id: Optional[int]) -> None: # pragma: no cover logger.debug(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}") - def _get_data(self, worker_id: int, partition_id: int) -> tuple[list[int], list[bytes], list[int]]: + def _get_data(self, data_container: dict, worker_id: int, partition_id: int): self._info("Getting keys from selector", worker_id) keys = self._get_keys_from_selector(worker_id, partition_id) self._info("Getting data from storage", worker_id) data, labels = self._get_data_from_storage(keys) + + data_container["data"] = data + data_container["labels"] = labels + data_container["keys"] = keys + + def _run_get_data_thread(self, worker_id: int, partition_id: int): + assert "_data_thread" not in self.__dict__.keys() or self._data_thread is None + self._thread_data_container = {} + self._data_thread = threading.Thread( + target=self._get_data, args=(self._thread_data_container, worker_id, partition_id) + ) + self._data_thread.start() + + def _wait_for_data_thread(self, worker_id: int) -> tuple[list[int], list[bytes], list[int]]: + assert self._data_thread is not None + self._info("Joining data thread.", worker_id) + self._data_thread.join() + self._data_thread = None + + assert "data" in self._thread_data_container + assert "labels" in self._thread_data_container + assert "keys" in self._thread_data_container + + keys, data, labels = ( + self._thread_data_container["keys"], + self._thread_data_container["data"], + self._thread_data_container["labels"], + ) + del self._thread_data_container + return keys, data, labels def _get_num_data_partitions(self) -> int: @@ -179,18 +210,19 @@ def __iter__(self) -> Generator: self._num_partitions = self._get_num_data_partitions() self._info(f"Total number of partitions will be {self._num_partitions}", worker_id) - keys, data, labels = self._get_data(worker_id=worker_id, partition_id=0) + self._run_get_data_thread(worker_id=worker_id, partition_id=0) + keys, data, labels = self._wait_for_data_thread(worker_id) for partition in range(self._num_partitions): num_samples_on_this_partition = len(keys) - # We (arbitrarily) fetch the next partition when we have seen 80% of the current partition - fetch_next_partition_idx = int(num_samples_on_this_partition * 0.8) + # We (arbitrarily) fetch the next partition when we have seen 70% of the current partition + fetch_next_partition_idx = int(num_samples_on_this_partition * 0.7) self._info(f"Train on partition {partition}, on {num_samples_on_this_partition} batches", worker_id) for idx, (key, sample, label) in enumerate(zip(keys, data, labels)): if partition < self._num_partitions - 1 and idx == fetch_next_partition_idx: # TODO(#175) in case this blocks training - new_keys, new_data, new_labels = self._get_data(worker_id=worker_id, partition_id=partition + 1) + self._run_get_data_thread(worker_id=worker_id, partition_id=partition + 1) # mypy complains here because _transform has unknown type, which is ok yield key, self._transform(sample), label # type: ignore @@ -199,8 +231,6 @@ def __iter__(self) -> Generator: del keys del data del labels - keys, data, labels = new_keys, new_data, new_labels - del new_keys - del new_data - del new_labels + # Wait for data fetching to finish + keys, data, labels = self._wait_for_data_thread(worker_id) gc.collect() From 0e54d4557e7344f19376b57d2577944bfdf8819f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20B=C3=B6ther?= Date: Thu, 11 May 2023 11:08:39 +0200 Subject: [PATCH 2/2] fix linting issues --- .../internal/dataset/online_dataset.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/modyn/trainer_server/internal/dataset/online_dataset.py b/modyn/trainer_server/internal/dataset/online_dataset.py index fd40e33b5..971fe0df9 100644 --- a/modyn/trainer_server/internal/dataset/online_dataset.py +++ b/modyn/trainer_server/internal/dataset/online_dataset.py @@ -57,6 +57,8 @@ def __init__( self._selectorstub: SelectorStub = None self._bytes_parser_function: Optional[Callable] = None self._num_partitions = 0 + self._data_thread: Optional[threading.Thread] = None + self._thread_data_container: dict[str, Any] = {} logger.debug("Initialized OnlineDataset.") @@ -139,7 +141,7 @@ def _info(self, msg: str, worker_id: Optional[int]) -> None: # pragma: no cover def _debug(self, msg: str, worker_id: Optional[int]) -> None: # pragma: no cover logger.debug(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}") - def _get_data(self, data_container: dict, worker_id: int, partition_id: int): + def _get_data(self, data_container: dict, worker_id: int, partition_id: int) -> None: self._info("Getting keys from selector", worker_id) keys = self._get_keys_from_selector(worker_id, partition_id) self._info("Getting data from storage", worker_id) @@ -149,9 +151,9 @@ def _get_data(self, data_container: dict, worker_id: int, partition_id: int): data_container["labels"] = labels data_container["keys"] = keys - def _run_get_data_thread(self, worker_id: int, partition_id: int): - assert "_data_thread" not in self.__dict__.keys() or self._data_thread is None - self._thread_data_container = {} + def _run_get_data_thread(self, worker_id: int, partition_id: int) -> None: + assert self._data_thread is None + self._data_thread = threading.Thread( target=self._get_data, args=(self._thread_data_container, worker_id, partition_id) ) @@ -172,7 +174,8 @@ def _wait_for_data_thread(self, worker_id: int) -> tuple[list[int], list[bytes], self._thread_data_container["data"], self._thread_data_container["labels"], ) - del self._thread_data_container + self._thread_data_container.clear() + gc.collect() return keys, data, labels