From f4b38db6881ea1eb0fd8bc21ec358ad88d98ab2f Mon Sep 17 00:00:00 2001
From: Michael Diggin
Date: Sun, 19 Jan 2025 17:41:41 +0000
Subject: [PATCH 1/3] Add in-order flag and implementation

---
 test/stateful_dataloader/test_dataloader.py | 118 ++++++++++++++++++
 .../stateful_dataloader.py                  |  58 +++++++--
 2 files changed, 167 insertions(+), 9 deletions(-)

diff --git a/test/stateful_dataloader/test_dataloader.py b/test/stateful_dataloader/test_dataloader.py
index 40ed43cdb..17abd0dc5 100644
--- a/test/stateful_dataloader/test_dataloader.py
+++ b/test/stateful_dataloader/test_dataloader.py
@@ -3024,6 +3024,124 @@ def test_conv_after_fork(self):
         self.assertEqual(x.shape, (1, 1, 1, 23999))
 
 
+class _TestSlowIndexDataset(Dataset):
+    def __init__(self, end: int, slow_index: int):
+        self.end = end
+        self.slow_index = slow_index
+        self._worker_id = None
+
+    def __getitem__(self, idx):
+        if self._worker_id is None:
+            worker_info = torch.utils.data.get_worker_info()
+            self._worker_id = worker_info.id
+        if idx == self.slow_index:
+            time.sleep(1.0)
+        return (self._worker_id, idx)
+
+    def __len__(self):
+        return self.end
+
+
+class _TestSlowIterableDataset(IterableDataset):
+    def __init__(self, start: int, end: int):
+        self.start = start
+        self.end = end
+        self.mid = math.ceil((self.end - self.start) / 2)
+
+    def give_data(self, worker_id, iter_start, iter_end):
+        for i in range(iter_start, iter_end):
+            if i == self.mid:
+                time.sleep(1.0)
+            yield (worker_id, i)
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
+        worker_id = worker_info.id
+        iter_start = self.start + worker_id * per_worker
+        iter_end = min(iter_start + per_worker, self.end)
+        return self.give_data(worker_id, iter_start, iter_end)
+
+
+class TestOutOfOrderDataLoader(TestCase):
+    def test_in_order_index_ds(self):
+        dataset = _TestSlowIndexDataset(end=10, slow_index=0)
+
+        dataloader = DataLoader(
+            dataset,
+            num_workers=2,
+            in_order=True,
+        )
+
+        expected_worker_ids = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+        expected_data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        outputs = list(dataloader)
+        worker_ids = [o[0] for o in outputs]
+        data = [o[1] for o in outputs]
+        self.assertEqual(expected_worker_ids, worker_ids)
+        self.assertEqual(expected_data, data)
+
+    def test_out_of_order_index_ds(self):
+        dataset = _TestSlowIndexDataset(end=10, slow_index=0)
+
+        dataloader = DataLoader(
+            dataset,
+            num_workers=2,
+            prefetch_factor=2,
+            in_order=False,
+        )
+
+        # worker_id = 0 gets 'stuck' on 0 and also has 2 in its queue
+        # due to prefetch_factor being 2
+        # this makes the test more deterministic as [0, 2] will be the last elements
+        expected_worker_ids = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
+        expected_data = [1, 3, 4, 5, 6, 7, 8, 9, 0, 2]
+        outputs = list(dataloader)
+        worker_ids = [o[0].item() for o in outputs]
+        data = [o[1].item() for o in outputs]
+        self.assertEqual(expected_worker_ids, worker_ids)
+        self.assertNotEqual(data, list(range(10)))
+        self.assertEqual(expected_data, data)
+
+    def test_in_order_iterable_ds(self):
+        dataset = _TestSlowIterableDataset(start=0, end=10)
+
+        dataloader = DataLoader(
+            dataset,
+            num_workers=2,
+            in_order=True,
+        )
+
+        expected_worker_ids = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+        expected_data = [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]
+        outputs = list(dataloader)
+        worker_ids = [o[0] for o in outputs]
+        data = [o[1] for o in outputs]
+        self.assertEqual(expected_worker_ids, worker_ids)
+        self.assertEqual(expected_data, data)
+
+    def test_out_of_order_iterable_ds(self):
+        dataset = _TestSlowIterableDataset(start=0, end=10)
+
+        dataloader = DataLoader(
+            dataset,
+            num_workers=2,
+            in_order=False,
+        )
+
+        # worker 0 has [0, 1, 2, 3, 4], worker 1 has [5, 6, 7, 8, 9]
+        # index 5 is slow, so expect all of worker 0 before worker 1
+        expected_worker_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+        expected_data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        outputs = list(dataloader)
+        worker_ids = [o[0] for o in outputs]
+        data = [o[1] for o in outputs]
+        self.assertEqual(expected_worker_ids, worker_ids)
+        self.assertEqual(sum(worker_ids), 5)
+        self.assertNotEqual(data, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])
+        self.assertEqual(expected_data, data)
+
+
 instantiate_device_type_tests(TestDataLoaderDeviceType, globals())
diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py
index 9b162b4f8..5f4522af6 100644
--- a/torchdata/stateful_dataloader/stateful_dataloader.py
+++ b/torchdata/stateful_dataloader/stateful_dataloader.py
@@ -147,6 +147,8 @@ class StatefulDataLoader(DataLoader[_T_co]):
             maintain the workers `Dataset` instances alive. (default: ``False``)
         pin_memory_device (str, optional): the device to :attr:`pin_memory` to if
             ``pin_memory`` is ``True``.
+        in_order (bool, optional): If ``False``, the data loader will not enforce that batches
+            are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)
         snapshot_every_n_steps (int, optional): Defines how often the state is transferred from the dataloader
             workers to the dataloader. By default, it is set to ``1``, i.e., state is transferred every step. If the state
             is large, this value can be increased (and ideally set to the frequency of training checkpointing) to reduce the overhead of transferring state every step.
@@ -177,6 +179,10 @@ class StatefulDataLoader(DataLoader[_T_co]):
     .. warning:: See `Reproducibility `_, and
                  `Dataloader-workers-random-seed `_, and
                  `Data-loading-randomness `_ notes for random seed related questions.
 
+    .. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data
+                 distribution being fed to the trainer in cases with imbalanced data.
+
+    .. warning:: Setting `in_order` to `False` currently has no guarantees for state management.
 
     .. _multiprocessing context: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
     """
@@ -202,6 +208,7 @@ def __init__(
         prefetch_factor: Optional[int] = None,
         persistent_workers: bool = False,
         pin_memory_device: str = "",
+        in_order: bool = True,
         snapshot_every_n_steps: Optional[int] = 1,
     ):
         torch._C._log_api_usage_once("python.stateful_data_loader")
@@ -227,6 +234,13 @@ def __init__(
         if persistent_workers and num_workers == 0:
             raise ValueError("persistent_workers option needs num_workers > 0")
 
+        if num_workers > 0 and not in_order:
+            # TODO: remove warning log when state management is supported with in_order=False
+            logger.warning(
+                "using in_order=False with multiple workers does not give any guarantees for state management "
+                "and loading from a checkpoint may not work as expected."
+            )
+
         self.dataset = dataset
         self.num_workers = num_workers
         self.prefetch_factor = prefetch_factor
@@ -235,6 +249,7 @@ def __init__(
         self.timeout = timeout
         self.worker_init_fn = worker_init_fn
         self.multiprocessing_context = multiprocessing_context
+        self.in_order = in_order
 
         # Adds forward compatibilities so classic DataLoader can work with DataPipes:
         #   _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
@@ -876,6 +891,7 @@ def __init__(self, loader, next_iter_state):
         super().__init__(loader)
         self._snapshot_interval = loader.snapshot_every_n_steps
         self._prefetch_factor = loader.prefetch_factor
+        self._in_order = loader.in_order
 
         assert self._num_workers > 0
         assert self._prefetch_factor > 0
@@ -1083,6 +1099,11 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True):
         # It does not mean that a worker is dead. In case of `_persistent_workers`,
         # the worker will be reset to available in the next epoch.
         self._workers_status = [True for i in range(self._num_workers)]
+        # A list of integers representing how many tasks are outstanding for each worker
+        # Incremented when a task is dispatched to the worker
+        # Decremented when that data has been given to the main thread
+        # Each worker should have at most self._prefetch_factor tasks outstanding
+        self._workers_num_tasks = [0 for i in range(self._num_workers)]
         # Reset the worker queue cycle so it resumes next epoch at worker 0
         self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
         remaining = self._num_workers
@@ -1352,11 +1373,12 @@ def _next_data(self):
             #   call and `_IterableDatasetStopIteration` check below can mark
             #   extra worker(s) as dead.
             while self._rcvd_idx < self._send_idx:
-                info = self._task_info[self._rcvd_idx]
-                worker_id = info[0]
-                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
-                    break
-                del self._task_info[self._rcvd_idx]
+                info = self._task_info.get(self._rcvd_idx, None)
+                if info:
+                    worker_id = info[0]
+                    if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
+                        break
+                    del self._task_info[self._rcvd_idx]
                 self._rcvd_idx += 1
             else:
                 # no valid `self._rcvd_idx` is found (i.e., didn't break)
@@ -1374,6 +1396,7 @@ def _next_data(self):
                     self._rcvd_idx += 1
                     continue
                 else:
+                    self._rcvd_idx += 1
                     return self._process_data(data, worker_id, state_dict)
 
             assert not self._shutdown and self._tasks_outstanding > 0
@@ -1394,6 +1417,13 @@ def _next_data(self):
 
             if idx != self._rcvd_idx:
                 # store out-of-order samples
+                if not self._in_order:
+                    # don't store it for later, process now
+                    if isinstance(data, _utils.worker._IterableDatasetStopIteration):
+                        self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict)
+                        continue
+                    del self._task_info[idx]
+                    return self._process_data(data, worker_id, state_dict)
                 self._task_info[idx] += ((data, worker_id, state_dict),)
             else:
                 del self._task_info[idx]
@@ -1402,6 +1432,7 @@ def _next_data(self):
                 self._rcvd_idx += 1
                 continue
             else:
+                self._rcvd_idx += 1
                 return self._process_data(data, worker_id, state_dict)
 
     def _get_main_state(self):
@@ -1433,7 +1464,8 @@ def _restore_main_state(self, state_dict):
         self._base_seed = state_dict[self._BASE_SEED]
 
     def _try_put_index(self):
-        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
+        max_tasks = self._prefetch_factor * self._num_workers
+        assert self._tasks_outstanding < max_tasks
 
         try:
             index = self._next_index()
@@ -1461,7 +1493,12 @@ def _try_put_index(self):
         for _ in range(self._num_workers):  # find the next active worker, if any
             worker_queue_idx = next(self._worker_queue_idx_cycle)
             if self._workers_status[worker_queue_idx]:
-                break
+                if self._in_order:
+                    break
+                elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(self._workers_status):
+                    # when self._in_order is False, distribute work to a worker if it has capacity
+                    # _workers_status is updated only in this thread, so the sum is guaranteed > 0
+                    break
         else:
             # not found (i.e., didn't break)
             return
@@ -1472,11 +1509,12 @@ def _try_put_index(self):
 
         self._index_queues[worker_queue_idx].put((self._send_idx, (index, snapshot)))  # type: ignore[possibly-undefined]
         self._task_info[self._send_idx] = (worker_queue_idx,)
+        self._workers_num_tasks[worker_queue_idx] += 1
         self._tasks_outstanding += 1
         self._send_idx += 1
 
     def _process_data(self, data, worker_id, state_dict):
-        self._rcvd_idx += 1
+        self._workers_num_tasks[worker_id] -= 1
         self._try_put_index()
         if isinstance(data, ExceptionWrapper):
             data.reraise()
@@ -1489,9 +1527,11 @@ def _process_data(self, data, worker_id, state_dict):
         return data
 
     def _take_snapshot(self):
+        main_snapshot_idx, main_snapshot = self._main_snapshots.popleft()
         while len(self._main_snapshots) and (self._main_snapshots[0][0] <= self._rcvd_idx - 1):
             main_snapshot_idx, main_snapshot = self._main_snapshots.popleft()
-        assert main_snapshot_idx == self._rcvd_idx - 1, (main_snapshot_idx, self._rcvd_idx - 1)
+        if self._in_order:
+            assert main_snapshot_idx == self._rcvd_idx - 1, (main_snapshot_idx, self._rcvd_idx - 1)
         self._update_snapshot(
             self._num_yielded + 1,
             self._last_yielded_worker_id,

From 43622b54fcc46bde0cfd2e22c6aa59054434b4f3 Mon Sep 17 00:00:00 2001
From: Michael Diggin
Date: Mon, 20 Jan 2025 08:00:22 +0000
Subject: [PATCH 2/3] handle snapshotting edge case

---
 torchdata/stateful_dataloader/stateful_dataloader.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py
index 5f4522af6..19fff5570 100644
--- a/torchdata/stateful_dataloader/stateful_dataloader.py
+++ b/torchdata/stateful_dataloader/stateful_dataloader.py
@@ -1527,11 +1527,14 @@ def _process_data(self, data, worker_id, state_dict):
         return data
 
     def _take_snapshot(self):
-        main_snapshot_idx, main_snapshot = self._main_snapshots.popleft()
+        main_snapshot_idx = None
         while len(self._main_snapshots) and (self._main_snapshots[0][0] <= self._rcvd_idx - 1):
             main_snapshot_idx, main_snapshot = self._main_snapshots.popleft()
-        if self._in_order:
-            assert main_snapshot_idx == self._rcvd_idx - 1, (main_snapshot_idx, self._rcvd_idx - 1)
+        if not self._in_order and main_snapshot_idx is None:
+            # in_order is False and no main snapshot is available as we're ahead of rcvd_idx
+            # we can't take a snapshot with the current implementation
+            return
+        assert main_snapshot_idx == self._rcvd_idx - 1, (main_snapshot_idx, self._rcvd_idx - 1)
         self._update_snapshot(
             self._num_yielded + 1,
             self._last_yielded_worker_id,

From 57ba9e5ad79ec19e92e1319c68e04d16c90cebd4 Mon Sep 17 00:00:00 2001
From: Michael Diggin
Date: Wed, 22 Jan 2025 07:42:51 +0000
Subject: [PATCH 3/3] add warning log in state_dict call

---
 torchdata/stateful_dataloader/stateful_dataloader.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py
index 19fff5570..078b378ee 100644
--- a/torchdata/stateful_dataloader/stateful_dataloader.py
+++ b/torchdata/stateful_dataloader/stateful_dataloader.py
@@ -1170,6 +1170,12 @@ def _update_worker_snapshot(self, worker_key, state_dict):
         self._worker_snapshots[worker_key].apply_delta(state_dict)
 
     def state_dict(self):
+        if not self._in_order:
+            # TODO: remove warning log when state management is supported with in_order=False
+            logger.warning(
+                "using in_order=False with multiple workers does not give any guarantees for state management "
+                "and loading from a checkpoint may not work as expected."
+            )
         steps_since_snapshot = self._num_yielded - self._snapshot[self._SNAPSHOT_STEP]
         state_dict = {
             self._SNAPSHOT: self._snapshot,
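
Usage sketch (not part of the patch series): a minimal end-to-end example of the
new ``in_order`` flag. It assumes a torchdata build with these three patches
applied; ``SlowDataset`` is a hypothetical stand-in mirroring the slow-index
dataset used in the tests above.

    import time

    import torch
    from torchdata.stateful_dataloader import StatefulDataLoader


    class SlowDataset(torch.utils.data.Dataset):
        # Hypothetical stand-in: index 0 sleeps ~1s, every other index
        # returns immediately, so one worker lags behind the other.
        def __len__(self):
            return 10

        def __getitem__(self, idx):
            if idx == 0:
                time.sleep(1.0)
            return idx


    def main():
        # With in_order=False the fast worker keeps yielding while the slow
        # worker is stuck on index 0, so samples arrive out of FIFO order.
        loader = StatefulDataLoader(SlowDataset(), num_workers=2, in_order=False)
        print([batch.item() for batch in loader])  # e.g. [1, 3, 4, ..., 0, 2]

        # state_dict() still runs, but as the warning added in PATCH 3/3 notes,
        # resuming from it is not guaranteed to be exact while in_order=False.
        snapshot = loader.state_dict()
        print(sorted(snapshot.keys()))


    if __name__ == "__main__":
        main()

The trade-off shown here is the one the docstring warnings describe: giving up
FIFO ordering buys throughput when worker latencies are skewed, at the cost of
reproducibility and exact checkpoint resumption.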