[Stateful DL] Add out of order implementation #1423

Merged
118 changes: 118 additions & 0 deletions test/stateful_dataloader/test_dataloader.py
@@ -3024,6 +3024,124 @@ def test_conv_after_fork(self):
        self.assertEqual(x.shape, (1, 1, 1, 23999))


class _TestSlowIndexDataset(Dataset):
    def __init__(self, end: int, slow_index: int):
        self.end = end
        self.slow_index = slow_index
        self._worker_id = None

    def __getitem__(self, idx):
        if self._worker_id is None:  # cache the worker id on first access
            worker_info = torch.utils.data.get_worker_info()
            self._worker_id = worker_info.id
        if idx == self.slow_index:
            time.sleep(1.0)
        return (self._worker_id, idx)

    def __len__(self):
        return self.end


class _TestSlowIterableDataset(IterableDataset):
    def __init__(self, start: int, end: int):
        self.start = start
        self.end = end
        self.mid = math.ceil((self.end - self.start) / 2)

    def give_data(self, worker_id, iter_start, iter_end):
        for i in range(iter_start, iter_end):
            if i == self.mid:
                time.sleep(1.0)
            yield (worker_id, i)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        worker_id = worker_info.id
        iter_start = self.start + worker_id * per_worker
        iter_end = min(iter_start + per_worker, self.end)
        return self.give_data(worker_id, iter_start, iter_end)


class TestOutOfOrderDataLoader(TestCase):
    def test_in_order_index_ds(self):
        dataset = _TestSlowIndexDataset(end=10, slow_index=0)

        dataloader = DataLoader(
            dataset,
            num_workers=2,
            in_order=True,
        )

        expected_worker_ids = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
        expected_data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        outputs = list(dataloader)
        worker_ids = [o[0] for o in outputs]
        data = [o[1] for o in outputs]
        self.assertEqual(expected_worker_ids, worker_ids)
        self.assertEqual(expected_data, data)

    def test_out_of_order_index_ds(self):
        dataset = _TestSlowIndexDataset(end=10, slow_index=0)

        dataloader = DataLoader(
            dataset,
            num_workers=2,
            prefetch_factor=2,
            in_order=False,
        )

        # worker_id=0 gets 'stuck' on index 0 and also has index 2 in its queue
        # due to prefetch_factor being 2;
        # this makes the test more deterministic, as [0, 2] will be the last elements
        expected_worker_ids = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
        expected_data = [1, 3, 4, 5, 6, 7, 8, 9, 0, 2]
        outputs = list(dataloader)
        worker_ids = [o[0].item() for o in outputs]
        data = [o[1].item() for o in outputs]
        self.assertEqual(expected_worker_ids, worker_ids)
        self.assertNotEqual(data, list(range(10)))
        self.assertEqual(expected_data, data)

    def test_in_order_iterable_ds(self):
        dataset = _TestSlowIterableDataset(start=0, end=10)

        dataloader = DataLoader(
            dataset,
            num_workers=2,
            in_order=True,
        )

        expected_worker_ids = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
        expected_data = [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]
        outputs = list(dataloader)
        worker_ids = [o[0] for o in outputs]
        data = [o[1] for o in outputs]
        self.assertEqual(expected_worker_ids, worker_ids)
        self.assertEqual(expected_data, data)

    def test_out_of_order_iterable_ds(self):
        dataset = _TestSlowIterableDataset(start=0, end=10)

        dataloader = DataLoader(
            dataset,
            num_workers=2,
            in_order=False,
        )

        # worker 0 has [0, 1, 2, 3, 4], worker 1 has [5, 6, 7, 8, 9]
        # index 5 is slow, so expect all of worker 0's data before worker 1's
        expected_worker_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
        expected_data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        outputs = list(dataloader)
        worker_ids = [o[0] for o in outputs]
        data = [o[1] for o in outputs]
        self.assertEqual(expected_worker_ids, worker_ids)
        self.assertEqual(sum(worker_ids), 5)
        self.assertNotEqual(data, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])
        self.assertEqual(expected_data, data)


instantiate_device_type_tests(TestDataLoaderDeviceType, globals())
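For reference, a standalone sketch of the sharding arithmetic `_TestSlowIterableDataset.__iter__` uses above: with `num_workers=2`, worker 0 takes indices [0, 5) and worker 1 takes [5, 10).

```python
import math

def shard(start, end, num_workers):
    # Same per-worker split as _TestSlowIterableDataset.__iter__
    per_worker = math.ceil((end - start) / num_workers)
    return [
        (w, list(range(start + w * per_worker, min(start + (w + 1) * per_worker, end))))
        for w in range(num_workers)
    ]

print(shard(0, 10, 2))  # [(0, [0, 1, 2, 3, 4]), (1, [5, 6, 7, 8, 9])]
```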


65 changes: 57 additions & 8 deletions torchdata/stateful_dataloader/stateful_dataloader.py
@@ -147,6 +147,8 @@ class StatefulDataLoader(DataLoader[_T_co]):
            maintain the workers `Dataset` instances alive. (default: ``False``)
        pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
            ``True``.
+       in_order (bool, optional): If ``False``, the data loader will not enforce that batches
+           are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)
        snapshot_every_n_steps (int, optional): Defines how often the state is
            transferred from the dataloader workers to the dataloader. By default, it is set to ``1``, i.e., state is
            transferred every step. If the state is large, this value can be increased (and ideally set to the
            frequency of training checkpointing) to reduce the overhead of transferring state every step.
@@ -177,6 +179,10 @@ class StatefulDataLoader(DataLoader[_T_co]):
    .. warning:: See `Reproducibility <https://pytorch.org/docs/stable/notes/randomness.html#reproducibility>`_,
        `Dataloader-workers-random-seed <https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed>`_, and
        `Data-loading-randomness <https://pytorch.org/docs/stable/data.html#data-loading-randomness>`_ notes for
        random-seed-related questions.

+   .. warning:: Setting ``in_order`` to ``False`` can harm reproducibility and may lead to a skewed data
+       distribution being fed to the trainer in cases with imbalanced data.
+
+   .. warning:: Setting ``in_order`` to ``False`` currently provides no guarantees for state management.

    .. _multiprocessing context:
        https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
    """
@@ -202,6 +208,7 @@ def __init__(
        prefetch_factor: Optional[int] = None,
        persistent_workers: bool = False,
        pin_memory_device: str = "",
+       in_order: bool = True,
        snapshot_every_n_steps: Optional[int] = 1,
    ):
        torch._C._log_api_usage_once("python.stateful_data_loader")
@@ -227,6 +234,13 @@ def __init__(
        if persistent_workers and num_workers == 0:
            raise ValueError("persistent_workers option needs num_workers > 0")

+       if num_workers > 0 and not in_order:
+           # TODO: remove warning log when state management is supported with in_order=False
+           logger.warning(
+               "using in_order=False with multiple workers does not provide any guarantees for state management, "
+               "and loading from a checkpoint may not work as expected."
+           )
+
        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
@@ -235,6 +249,7 @@ def __init__(
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context
+       self.in_order = in_order

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
@@ -876,6 +891,7 @@ def __init__(self, loader, next_iter_state):
        super().__init__(loader)
        self._snapshot_interval = loader.snapshot_every_n_steps
        self._prefetch_factor = loader.prefetch_factor
+       self._in_order = loader.in_order

        assert self._num_workers > 0
        assert self._prefetch_factor > 0
@@ -1083,6 +1099,11 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True):
        # It does not mean that a worker is dead. In case of `_persistent_workers`,
        # the worker will be reset to available in the next epoch.
        self._workers_status = [True for i in range(self._num_workers)]
+       # A list of integers representing how many tasks are outstanding for each worker
+       # Incremented when a task is dispatched to the worker
+       # Decremented when that data has been given to the main thread
+       # Each worker should have at most self._prefetch_factor tasks outstanding
+       self._workers_num_tasks = [0 for i in range(self._num_workers)]
        # Reset the worker queue cycle so it resumes next epoch at worker 0
        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        remaining = self._num_workers
@@ -1149,6 +1170,12 @@ def _update_worker_snapshot(self, worker_key, state_dict):
        self._worker_snapshots[worker_key].apply_delta(state_dict)

    def state_dict(self):
+       if not self._in_order:
+           # TODO: remove warning log when state management is supported with in_order=False
+           logger.warning(
+               "using in_order=False with multiple workers does not provide any guarantees for state management, "
+               "and loading from a checkpoint may not work as expected."
+           )
        steps_since_snapshot = self._num_yielded - self._snapshot[self._SNAPSHOT_STEP]
        state_dict = {
            self._SNAPSHOT: self._snapshot,
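A hedged sketch of what this warning means in practice (`ds` is again a placeholder; the exact resume behavior is not guaranteed):

```python
dl = StatefulDataLoader(ds, num_workers=2, in_order=False)
it = iter(dl)
_ = next(it)
state = dl.state_dict()  # logs the warning above when in_order=False

dl2 = StatefulDataLoader(ds, num_workers=2, in_order=False)
dl2.load_state_dict(state)
# resuming from `state` may not replay the remaining samples as expected
```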
@@ -1352,11 +1379,12 @@ def _next_data(self):
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
-               info = self._task_info[self._rcvd_idx]
-               worker_id = info[0]
-               if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
-                   break
-               del self._task_info[self._rcvd_idx]
+               info = self._task_info.get(self._rcvd_idx, None)
+               if info:
+                   worker_id = info[0]
+                   if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
+                       break
+                   del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
@@ -1374,6 +1402,7 @@ def _next_data(self):
                    self._rcvd_idx += 1
                    continue
                else:
+                   self._rcvd_idx += 1
                    return self._process_data(data, worker_id, state_dict)

            assert not self._shutdown and self._tasks_outstanding > 0
@@ -1394,6 +1423,13 @@ def _next_data(self):

            if idx != self._rcvd_idx:
                # store out-of-order samples
+               if not self._in_order:
+                   # don't store it for later, process now
+                   if isinstance(data, _utils.worker._IterableDatasetStopIteration):
+                       self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict)
+                       continue
+                   del self._task_info[idx]
+                   return self._process_data(data, worker_id, state_dict)
                self._task_info[idx] += ((data, worker_id, state_dict),)
            else:
                del self._task_info[idx]
@@ -1402,6 +1438,7 @@ def _next_data(self):
                    self._rcvd_idx += 1
                    continue
                else:
+                   self._rcvd_idx += 1
                    return self._process_data(data, worker_id, state_dict)

    def _get_main_state(self):
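The net effect of the branches above, as a toy model (simplified; not the real implementation): in-order mode buffers early arrivals until `_rcvd_idx` catches up, while out-of-order mode yields them immediately.

```python
def drain(arrivals, in_order):
    """arrivals: list of (idx, sample) pairs in worker completion order."""
    out, buffered, rcvd = [], {}, 0
    for idx, sample in arrivals:
        if in_order:
            buffered[idx] = sample
            while rcvd in buffered:  # release strictly by index
                out.append(buffered.pop(rcvd))
                rcvd += 1
        else:
            out.append(sample)  # release in completion order
    return out

print(drain([(1, "b"), (0, "a"), (2, "c")], in_order=True))   # ['a', 'b', 'c']
print(drain([(1, "b"), (0, "a"), (2, "c")], in_order=False))  # ['b', 'a', 'c']
```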
@@ -1433,7 +1470,8 @@ def _restore_main_state(self, state_dict):
        self._base_seed = state_dict[self._BASE_SEED]

    def _try_put_index(self):
-       assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
+       max_tasks = self._prefetch_factor * self._num_workers
+       assert self._tasks_outstanding < max_tasks

        try:
            index = self._next_index()
@@ -1461,7 +1499,12 @@ def _try_put_index(self):
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
-               break
+               if self._in_order:
+                   break
+               elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(self._workers_status):
+                   # when self._in_order is False, distribute work to a worker if it has capacity
+                   # _workers_status is updated only in this thread, so the sum is guaranteed > 0
+                   break
        else:
            # not found (i.e., didn't break)
            return
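The capacity rule in the `elif` above, isolated as a sketch (hypothetical helper names): each active worker may hold at most its fair share of the global budget of `prefetch_factor * num_workers` outstanding tasks.

```python
def has_capacity(tasks_per_worker, workers_status, worker_id, prefetch_factor):
    max_tasks = prefetch_factor * len(workers_status)  # global prefetch budget
    active = sum(workers_status)  # caller guarantees at least one active worker
    return tasks_per_worker[worker_id] < max_tasks // active

# 2 workers, prefetch_factor=2: each active worker may hold at most 2 tasks
print(has_capacity([2, 0], [True, True], 0, 2))  # False -> worker 0 is full
print(has_capacity([2, 0], [True, True], 1, 2))  # True  -> dispatch to worker 1
```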
@@ -1472,11 +1515,12 @@

        self._index_queues[worker_queue_idx].put((self._send_idx, (index, snapshot)))  # type: ignore[possibly-undefined]
        self._task_info[self._send_idx] = (worker_queue_idx,)
+       self._workers_num_tasks[worker_queue_idx] += 1
        self._tasks_outstanding += 1
        self._send_idx += 1

    def _process_data(self, data, worker_id, state_dict):
-       self._rcvd_idx += 1
+       self._workers_num_tasks[worker_id] -= 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
@@ -1489,8 +1533,13 @@ def _process_data(self, data, worker_id, state_dict):
        return data

    def _take_snapshot(self):
+       main_snapshot_idx = None
        while len(self._main_snapshots) and (self._main_snapshots[0][0] <= self._rcvd_idx - 1):
            main_snapshot_idx, main_snapshot = self._main_snapshots.popleft()
+       if not self._in_order and main_snapshot_idx is None:
+           # in_order is False and no main snapshot is available as we're ahead of rcvd_idx
+           # we can't take a snapshot with the current implementation
+           return
        assert main_snapshot_idx == self._rcvd_idx - 1, (main_snapshot_idx, self._rcvd_idx - 1)
        self._update_snapshot(
            self._num_yielded + 1,