Asynchronous prefetching of data #253

Closed
wants to merge 2 commits into from
51 changes: 42 additions & 9 deletions modyn/trainer_server/internal/dataset/online_dataset.py
@@ -1,5 +1,6 @@
 import gc
 import logging
+import threading
 from inspect import isfunction
 from typing import Any, Callable, Generator, Optional
 
@@ -56,6 +57,8 @@ def __init__(
         self._selectorstub: SelectorStub = None
         self._bytes_parser_function: Optional[Callable] = None
         self._num_partitions = 0
+        self._data_thread: Optional[threading.Thread] = None
+        self._thread_data_container: dict[str, Any] = {}
 
         logger.debug("Initialized OnlineDataset.")
@@ -138,11 +141,42 @@ def _info(self, msg: str, worker_id: Optional[int]) -> None:  # pragma: no cover
     def _debug(self, msg: str, worker_id: Optional[int]) -> None:  # pragma: no cover
         logger.debug(f"[Training {self._training_id}][PL {self._pipeline_id}][Worker {worker_id}] {msg}")
 
-    def _get_data(self, worker_id: int, partition_id: int) -> tuple[list[int], list[bytes], list[int]]:
+    def _get_data(self, data_container: dict, worker_id: int, partition_id: int) -> None:
         self._info("Getting keys from selector", worker_id)
         keys = self._get_keys_from_selector(worker_id, partition_id)
         self._info("Getting data from storage", worker_id)
         data, labels = self._get_data_from_storage(keys)
+
+        data_container["data"] = data
+        data_container["labels"] = labels
+        data_container["keys"] = keys
+
+    def _run_get_data_thread(self, worker_id: int, partition_id: int) -> None:
+        assert self._data_thread is None
+
+        self._data_thread = threading.Thread(
+            target=self._get_data, args=(self._thread_data_container, worker_id, partition_id)
+        )
+        self._data_thread.start()
+
+    def _wait_for_data_thread(self, worker_id: int) -> tuple[list[int], list[bytes], list[int]]:
+        assert self._data_thread is not None
+        self._info("Joining data thread.", worker_id)
+        self._data_thread.join()
+        self._data_thread = None
+
+        assert "data" in self._thread_data_container
+        assert "labels" in self._thread_data_container
+        assert "keys" in self._thread_data_container
+
+        keys, data, labels = (
+            self._thread_data_container["keys"],
+            self._thread_data_container["data"],
+            self._thread_data_container["labels"],
+        )
+        self._thread_data_container.clear()
+        gc.collect()
 
         return keys, data, labels
 
     def _get_num_data_partitions(self) -> int:
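
Aside for reviewers: the three helpers above amount to a plain thread-plus-dict hand-off. Here is a minimal standalone sketch of that pattern, assuming a hypothetical `slow_fetch` in place of the selector/storage round trip; none of these module-level names exist in modyn.

```python
import threading
from typing import Any

def slow_fetch(partition_id: int) -> tuple[list[int], list[bytes], list[int]]:
    # Hypothetical stand-in for _get_keys_from_selector + _get_data_from_storage.
    keys = list(range(4))
    return keys, [b"sample"] * 4, [0] * 4

def fill(container: dict[str, Any], partition_id: int) -> None:
    # Runs in the worker thread; results are handed back via the shared dict.
    keys, data, labels = slow_fetch(partition_id)
    container["keys"], container["data"], container["labels"] = keys, data, labels

container: dict[str, Any] = {}
thread = threading.Thread(target=fill, args=(container, 1))
thread.start()
# ... training on the current partition would overlap with the fetch here ...
thread.join()  # join() also makes the worker's writes visible to this thread
keys, data, labels = container["keys"], container["data"], container["labels"]
container.clear()
```

Since at most one fetch is ever in flight (enforced by `assert self._data_thread is None`), the shared dict needs no lock; `join()` provides the necessary synchronization.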
@@ -179,18 +213,19 @@ def __iter__(self) -> Generator:
         self._num_partitions = self._get_num_data_partitions()
         self._info(f"Total number of partitions will be {self._num_partitions}", worker_id)
 
-        keys, data, labels = self._get_data(worker_id=worker_id, partition_id=0)
+        self._run_get_data_thread(worker_id=worker_id, partition_id=0)
+        keys, data, labels = self._wait_for_data_thread(worker_id)
 
         for partition in range(self._num_partitions):
             num_samples_on_this_partition = len(keys)
-            # We (arbitrarily) fetch the next partition when we have seen 80% of the current partition
-            fetch_next_partition_idx = int(num_samples_on_this_partition * 0.8)
+            # We (arbitrarily) fetch the next partition when we have seen 70% of the current partition
+            fetch_next_partition_idx = int(num_samples_on_this_partition * 0.7)
             self._info(f"Train on partition {partition}, on {num_samples_on_this_partition} batches", worker_id)
 
             for idx, (key, sample, label) in enumerate(zip(keys, data, labels)):
                 if partition < self._num_partitions - 1 and idx == fetch_next_partition_idx:
                     # TODO(#175) in case this blocks training
-                    new_keys, new_data, new_labels = self._get_data(worker_id=worker_id, partition_id=partition + 1)
+                    self._run_get_data_thread(worker_id=worker_id, partition_id=partition + 1)
                 # mypy complains here because _transform has unknown type, which is ok
                 yield key, self._transform(sample), label  # type: ignore
 
@@ -199,8 +234,6 @@ def __iter__(self) -> Generator:
                 del keys
                 del data
                 del labels
-                keys, data, labels = new_keys, new_data, new_labels
-                del new_keys
-                del new_data
-                del new_labels
+                # Wait for data fetching to finish
+                keys, data, labels = self._wait_for_data_thread(worker_id)
                 gc.collect()
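
Taken together, the diff turns `__iter__` into a double-buffered loop: the fetch of partition p+1 starts once 70% of partition p has been consumed and is joined at the partition boundary. Below is a self-contained sketch of that scheme, with an illustrative `fetch` callable standing in for modyn's selector/storage path; it is not the PR's exact code.

```python
import threading
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")

def prefetch_iter(num_partitions: int, fetch: Callable[[int], list[T]], trigger: float = 0.7) -> Iterator[T]:
    samples = fetch(0)  # the first partition is fetched up front
    for partition in range(num_partitions):
        thread = None
        nxt: dict = {}
        trigger_idx = int(len(samples) * trigger)
        for idx, sample in enumerate(samples):
            if partition < num_partitions - 1 and idx == trigger_idx:
                # Overlap fetching partition + 1 with the remaining 30% of training.
                thread = threading.Thread(target=lambda p=partition + 1: nxt.update(samples=fetch(p)))
                thread.start()
            yield sample
        if thread is not None:
            thread.join()  # blocks only if fetching is slower than training
            samples = nxt["samples"]

# Toy usage: three partitions of ten ints each, yielded in order.
assert list(prefetch_iter(3, lambda p: list(range(p * 10, (p + 1) * 10)))) == list(range(30))
```

The 70% trigger is arbitrary, as the comment in the diff notes: triggering earlier hides more fetch latency but holds two partitions in memory for longer, and TODO(#175) tracks the case where fetching still blocks training.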