Saving and restoring initial seed generator (#998)
Summary:
Pull Request resolved: #998

Changes to `DataLoader2`:
- Modifying `state_dict` to store `randomness_state` (a sketch of the saved entry follows this list), which includes:
    - `_seed: int`
    - `_reset_seed: bool` - a flag indicating whether `_seed` needs to be set
    - `_seed_generator` - the latest version at the time when `state_dict` is called
    - `_initial_seed_generator` - the version that is saved at the beginning of every epoch
- Modifying `from_state` and `load_state_dict` to restore `randomness_state`
- Adding a method `_restore_checkpoint_beginning_of_epoch`
    -  This sets `self._seed_generator = self._initial_seed_generator`, allowing users to re-create an epoch from the beginning.
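
A minimal sketch of the saved entry, using only APIs that appear in this diff (the key name and tuple layout mirror `state_dict` below):

```python
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(20)).shuffle().sharding_filter()
dl: DataLoader2 = DataLoader2(dp, reading_service=MultiProcessingReadingService(num_workers=2))
dl.seed(1)

state = dl.state_dict()
# state["randomness_state"] is the tuple:
#   (_seed, _reset_seed, pickle.dumps(_seed_generator), pickle.dumps(_initial_seed_generator))
```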

 ---
### Considerations

Storing the randomness states provides more flexibility for users to restore them as they see fit. That decision should not be controversial.

I decided to add a new method for checkpointing at the beginning of the epoch, to ensure that users are not confused about what randomness is restored by default.

The basic idea is that we want to allow users to restore `dl2._seed_generator` to the previously saved version. From that point on, they can create a new `__iter__` and continue from the beginning of the epoch (see the sketch after this list).
- Note that since `_seed` and `_reset_seed` are also saved, if the users were planning to use a different seed, or if there was a need to re-seed, those remain valid after restoring the checkpoint.
- Finally, if users change their mind at any point after restoring and manually set `seed`, that `seed` overrides any other behavior and will be used.
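
Condensed from the new test below, a sketch of that flow (save mid-epoch, restore, then replay the epoch from the top); wrap it in `if __name__ == "__main__":` when running with multiprocessing:

```python
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(20)).shuffle().sharding_filter()
rs = MultiProcessingReadingService(num_workers=2)

dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
it = iter(dl)                 # __iter__ captures the initial seed generator
first = next(it)              # begin iterating
state = dl.state_dict()       # checkpoint taken mid-epoch

restored: DataLoader2 = DataLoader2.from_state(state, rs)
restored._restore_checkpoint_beginning_of_epoch()  # rewind RNG to the epoch start
assert [first] + list(it) == list(restored)        # same epoch, replayed from the beginning

dl.shutdown()
restored.shutdown()
```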

Test Plan: Imported from OSS

Reviewed By: wenleix

Differential Revision: D44390519

Pulled By: NivekT

fbshipit-source-id: 0faa3d7aef23463e86765571018f83384bcb31e1
NivekT authored and facebook-github-bot committed Mar 28, 2023
1 parent aeda987 commit 4ba17dd
Showing 4 changed files with 110 additions and 5 deletions.
66 changes: 62 additions & 4 deletions test/dataloader2/test_mprs.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import multiprocessing as mp
import unittest
from unittest import TestCase
@@ -14,7 +13,7 @@
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


def _add_one(x: int) -> int:
@@ -46,6 +45,17 @@ def _dispatching_dp(n_elements=1000):
return dp


class NonShardableDataPipe(IterDataPipe):
def __init__(self, dp: IterDataPipe):
self.dp = dp

def is_replicable(self):
return False

def __iter__(self):
yield from self.dp


class TestMultiProcessingReadingService(TestCase):
r"""
This tests specific functionalities of MultiProcessingReadingService, notably
@@ -64,7 +74,7 @@ def test_early_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
worker_prefetch_cnt=worker_prefetch,
multiprocessing_context=ctx,
)
dl = DataLoader2(dp, reading_service=rs)
dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
it = iter(dl)
for _ in range(10):
_ = next(it)
@@ -82,7 +92,7 @@ def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
worker_prefetch_cnt=worker_prefetch,
multiprocessing_context=ctx,
)
dl = DataLoader2(dp, reading_service=rs)
dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
_ = list(dl)
dl.shutdown()

@@ -248,6 +258,54 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
res.append(x)
self.assertEqual(9, len(res))

def test_initial_epoch_checkpointing(self):
dp = IterableWrapper(range(20)).shuffle().sharding_filter()
# Note that the second `shuffle` occurs in the main process, which uses a different RNG from
# the `shuffle` done in the worker processes
dp = NonShardableDataPipe(dp).shuffle() # type: ignore[assignment, arg-type]
rs = MultiProcessingReadingService(num_workers=2)

# Functional Test: Saving state before iterator is created
dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
initial_state = dl.state_dict()
it1 = iter(dl)

restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
restored_dl._restore_checkpoint_beginning_of_epoch()
self.assertEqual(list(it1), list(restored_dl))

dl.shutdown()
restored_dl.shutdown()

# Functional Test: Saving state after iterator is created
dl = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
it1 = iter(dl)
initial_state = dl.state_dict()

restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
restored_dl._restore_checkpoint_beginning_of_epoch()
self.assertEqual(list(it1), list(restored_dl))

dl.shutdown()
restored_dl.shutdown()

# Functional Test: Saving state after iterator is created and began iterating
dl = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
it1 = iter(dl)
temp = next(it1) # Starts iterating
initial_state = dl.state_dict()

restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
restored_dl._restore_checkpoint_beginning_of_epoch()

        self.assertEqual([temp] + list(it1), list(restored_dl))  # prepend the element `it1` already consumed; the restored loader replays from the epoch start

dl.shutdown()
restored_dl.shutdown()

# TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding
# Currently, using sharding_round_robin raises a warning
# def test_round_robin_dispatching_pause_limit(self):
37 changes: 36 additions & 1 deletion torchdata/dataloader2/dataloader2.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import pickle
import warnings

from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union
@@ -19,6 +19,7 @@
T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
READING_SERVICE_STATE_KEY_NAME = "reading_service_state"
RANDOMNESS_STATE_KEY_NAME = "randomness_state"


class DataLoader2Iterator(Iterator[T_co]):
@@ -176,6 +177,8 @@ def __init__(
self._seed_generator: SeedGenerator = SeedGenerator()
self._seed: Optional[int] = None
self._reset_seed: bool = True
# Seed generator as of beginning of each epoch
self._initial_seed_generator: SeedGenerator = clone(self._seed_generator)

def __iter__(self) -> DataLoader2Iterator[T_co]:
r"""
@@ -198,6 +201,9 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
else:
self._seed_generator.seed()

# Saving initial seed generator state
self._initial_seed_generator = clone(self._seed_generator)

if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
self.datapipe = self.reading_service.initialize(self.datapipe)
@@ -269,10 +275,17 @@ def state_dict(self) -> Dict[str, Any]:

# Serialize datapipe after applying adapters and before reading service adaption
serialized_datapipe = serialize_datapipe(self._datapipe_before_reading_service_adapt)
serialized_randomness_state = (
self._seed,
self._reset_seed,
pickle.dumps(self._seed_generator),
pickle.dumps(self._initial_seed_generator),
)

return {
SERIALIZED_DATAPIPE_KEY_NAME: serialized_datapipe,
READING_SERVICE_STATE_KEY_NAME: reading_service_state,
RANDOMNESS_STATE_KEY_NAME: serialized_randomness_state,
}

@classmethod
@@ -294,6 +307,12 @@ def from_state(
reading_service=reading_service,
)
data_loader.reading_service_state = reading_service_state

        # Restore the randomness state captured by `state_dict`
        randomness_state = state[RANDOMNESS_STATE_KEY_NAME]
data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1]
data_loader._seed_generator = pickle.loads(randomness_state[2])
data_loader._initial_seed_generator = pickle.loads(randomness_state[3])

return data_loader

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -320,12 +339,28 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.datapipe = deserialized_datapipe
self.reading_service_state = reading_service_state

randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME]
self._seed, self._reset_seed = randomness_state[0], randomness_state[1]
self._seed_generator = pickle.loads(randomness_state[2])
self._initial_seed_generator = pickle.loads(randomness_state[3])

# re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
self.datapipe = adapter_fn(self.datapipe)
self._datapipe_before_reading_service_adapt = clone(self.datapipe)

def _restore_checkpoint_beginning_of_epoch(self) -> None:
r"""
At the beginning of each iteration (epoch), the initial state of randomness is automatically saved.
That state is also saved as part of ``state_dict``. This method restores the current DataLoader2 RNG state
to that initial state.
The common use case is to invoke this method after ``DataLoader2``'s state is restored (through
``.from_state(...)`` or ``load_state_dict(...)``) in order to resume from the beginning of the last-ran epoch.
"""
self._seed_generator = self._initial_seed_generator

def _pause(self):
if hasattr(self.reading_service, "_pause"):
self._is_paused = True
10 changes: 10 additions & 0 deletions torchdata/dataloader2/random/seed_generator.py
@@ -83,3 +83,13 @@ def spawn(self, worker_id: int, inplace: bool = False) -> "SeedGenerator":
self._worker_rng = self._worker_rng.spawn(worker_id)
return self
return SeedGenerator(seed=None, _rngs=(self._shared_rng.clone(), self._worker_rng.spawn(worker_id)))

    # Pickling support: `DataLoader2.state_dict` serializes the seed generators
    # with `pickle.dumps`, so only the two internal RNG states need to round-trip.
    def __getstate__(self):
state = (
self._shared_rng,
self._worker_rng,
)
return state

def __setstate__(self, state):
self._shared_rng, self._worker_rng = state
2 changes: 2 additions & 0 deletions torchdata/dataloader2/reading_service.py
@@ -312,6 +312,8 @@ def initialize_iteration(
) -> Optional[Callable[[DataPipe], DataPipe]]:
assert self._end_datapipe is not None

# Set random seeds for DataPipe that are in the main process (NOT those in worker processes)
# Worker seeds are set in `process_reset_fn`
set_graph_random_seed(self._end_datapipe, seed_generator)

if self._mp:
