From 1c69775eb6c82be8fdc0dae44096a89bb4d5f8b2 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Mon, 3 Feb 2025 15:32:39 -0800 Subject: [PATCH 01/24] add test for end of epoch state dict check --- test/stateful_dataloader/test_state_dict.py | 41 +++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 327b97a5d..f076c5ddb 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -1440,6 +1440,47 @@ def test_fast_state_dict_request(self) -> None: def test_fast_state_dict_request_skip_steps(self) -> None: self._run_test(17, 19) +class TestMultiEpochState(TestCase): + def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False): + dataset = DummyMapDataset(data_size, shuffle=shuffle) + return StatefulDataLoader( + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), + ) + + def _run(self, data_size, num_workers, batch_size, shuffle=False): + dataloader = self.get_map_dl(data_size=data_size,num_workers=num_workers, batch_size=batch_size, shuffle=shuffle) + # Run through the dataloader for 2 epochs and count the number of items yielded + num_items_yielded = 0 + for _ in range(2): + for _ in dataloader: + num_items_yielded += 1 + # Save the state dict + state_dict = dataloader.state_dict() + # Create a new StatefulDataLoader instance and load the state dict + new_dataloader = self.get_map_dl( + num_workers=num_workers, batch_size=batch_size, shuffle=shuffle + ) + new_dataloader.load_state_dict(state_dict) + # Run through the new dataloader for another 2 epochs and count the number of items yielded + additional_num_items_yielded = 0 + for i in range(2): + epoch_num_items_yielded = 0 + for _ in new_dataloader: + epoch_num_items_yielded += 1 + additional_num_items_yielded += epoch_num_items_yielded + # Check that the total number of items yielded is correct + self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size*4) + + + def test_main_process(self): + self._run(100, 0, 1, False) + def test_multiprocess(self): + self._run(100, 2, 1, False) class TestMultiEpochState_shard0(TestCase): def get_iterable_dl(self, pw, num_workers): From a074b5036616aa73c8e21a0bf2abdc3a617ce08a Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Mon, 3 Feb 2025 15:34:49 -0800 Subject: [PATCH 02/24] run precommit update stateful_dataloader run precommit local changes update test to test the order of batches update test update tests revert changes in SDL revert changes in SDL update tests run precommit --- test/stateful_dataloader/test_state_dict.py | 67 ++++++++++++++++----- 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index f076c5ddb..15ecbe037 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -1314,7 +1314,7 @@ def test(self): dataset=dataset, num_workers=num_workers, collate_fn=identity, - multiprocessing_context="forkserver" if IS_MACOS and num_workers else None, + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) # Fetch at least one batch from each worker @@ -1325,7 +1325,10 @@ def test(self): if num_workers > 0: for i in range(num_workers): # Ensure worker state is stored only once if the dataset is also the 
iterator - self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None) + self.assertEqual( + state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], + None, + ) self.assertTrue( state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][ "dataset_iter_state" @@ -1440,48 +1443,80 @@ def test_fast_state_dict_request(self) -> None: def test_fast_state_dict_request_skip_steps(self) -> None: self._run_test(17, 19) -class TestMultiEpochState(TestCase): + +class TestMultiEpochSDL_shard0(TestCase): def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False): - dataset = DummyMapDataset(data_size, shuffle=shuffle) + dataset = DummyMapDataset(data_size, shuffle=False) return StatefulDataLoader( dataset=dataset, num_workers=num_workers, batch_size=batch_size, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + shuffle=shuffle, + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) - + def _run(self, data_size, num_workers, batch_size, shuffle=False): - dataloader = self.get_map_dl(data_size=data_size,num_workers=num_workers, batch_size=batch_size, shuffle=shuffle) + dl1 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) # Run through the dataloader for 2 epochs and count the number of items yielded num_items_yielded = 0 + dl1_items = [] for _ in range(2): - for _ in dataloader: + for batch in dl1: + dl1_items.append(batch) num_items_yielded += 1 # Save the state dict - state_dict = dataloader.state_dict() + state_dict = dl1.state_dict() # Create a new StatefulDataLoader instance and load the state dict - new_dataloader = self.get_map_dl( - num_workers=num_workers, batch_size=batch_size, shuffle=shuffle + new_dl1 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, ) - new_dataloader.load_state_dict(state_dict) + new_dl1.load_state_dict(state_dict) # Run through the new dataloader for another 2 epochs and count the number of items yielded additional_num_items_yielded = 0 for i in range(2): epoch_num_items_yielded = 0 - for _ in new_dataloader: + for batch in new_dl1: + dl1_items.append(batch) epoch_num_items_yielded += 1 additional_num_items_yielded += epoch_num_items_yielded # Check that the total number of items yielded is correct - self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size*4) + self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4) + + # now run a second dataloder for 4 epochs and check if the order is same. 
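+        # dl2 is never checkpointed; comparing its 4-epoch output with dl1_items checks that
+        # saving and loading the state dict at the epoch boundary does not change the batch order.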
+ dl2 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) + dl2_items = [] + for _ in range(4): + for batch in dl2: + dl2_items.append(batch) + self.assertEqual(dl1_items, dl2_items) def test_main_process(self): self._run(100, 0, 1, False) + def test_multiprocess(self): self._run(100, 2, 1, False) + def test_main_process_shuffle(self): + self._run(100, 0, 1, True) + + def test_multiprocess_shuffle(self): + self._run(100, 2, 1, True) + + class TestMultiEpochState_shard0(TestCase): def get_iterable_dl(self, pw, num_workers): data_size = [25, 50, 100, 75] From 50271b4e7effc3baeb9ce2a5ff66432fbade249d Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Tue, 4 Feb 2025 17:09:24 -0800 Subject: [PATCH 03/24] update sampler --- torchdata/stateful_dataloader/sampler.py | 94 +++++++++++++----------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 4effec1d1..b9214bdd4 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -48,7 +48,11 @@ def state_dict(self) -> Dict[str, Any]: class RandomSampler(torch.utils.data.sampler.RandomSampler): def __init__( - self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None + self, + data_source: Sized, + replacement: bool = False, + num_samples: Optional[int] = None, + generator=None, ): if generator is None: # Ensure that underlying sampler has something repeatable @@ -60,16 +64,31 @@ def __iter__(self): return _StatefulRandomSamplerIterator(self, super().__iter__()) -class BatchSampler(torch.utils.data.sampler.BatchSampler, Stateful): +class _BatchSamplerIterator(Iterator[list[int]], Stateful): _SAMPLES_YIELDED = "samples_yielded" _SAMPLER_STATE = "sampler_state" _SAMPLER_ITER_STATE = "sampler_iter_state" - def __init__(self, sampler, batch_size, drop_last): - super().__init__(sampler, batch_size, drop_last) + def __init__(self, sampler, batch_size: int, drop_last: bool): + self.sampler = sampler + self.sampler_iter = iter(self.sampler) + self.batch_size = batch_size + self.drop_last = drop_last self.samples_yielded = 0 - self.next_yielded = None - self.sampler_iter = iter(sampler) + + def __next__(self) -> list[int]: + batch = [] + try: + for _ in range(self.batch_size): + batch.append(next(self.sampler_iter)) + self.samples_yielded += 1 + return batch + except StopIteration: + if self.drop_last or len(batch) == 0: + # Reset the iterator for the next epoch + raise StopIteration + else: + return batch def state_dict(self) -> Dict[str, Any]: sd: Dict[str, Any] = {self._SAMPLES_YIELDED: self.samples_yielded} @@ -80,7 +99,7 @@ def state_dict(self) -> Dict[str, Any]: return sd def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - self.next_yielded = state_dict[self._SAMPLES_YIELDED] + self.samples_yielded = state_dict[self._SAMPLES_YIELDED] if self._SAMPLER_STATE in state_dict: assert isinstance(self.sampler, Stateful) self.sampler.load_state_dict(state_dict[self._SAMPLER_STATE]) @@ -88,45 +107,32 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self._SAMPLER_ITER_STATE in state_dict: assert isinstance(self.sampler_iter, Stateful) self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) + if not ( + isinstance(self.sampler, Stateful) + or isinstance(self.sampler_iter, Stateful) + ): + # We skip x samples if underlying sampler is not stateful + for _ in 
range(self.samples_yielded): + next(self.sampler_iter) + # Skip one epoch if we were at the end of the last epoch + if hasattr(self.sampler, "__len__") and self.samples_yielded == len( + self.sampler + ): + + for _ in self.sampler_iter: + pass + + +class BatchSampler(torch.utils.data.sampler.BatchSampler): + def __init__(self, sampler, batch_size, drop_last): + super().__init__(sampler, batch_size, drop_last) def __iter__(self): - if self.next_yielded is not None: - self.samples_yielded = self.next_yielded - if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( - self.sampler, _InfiniteConstantSampler - ): - # We skip x samples if underlying sampler is not stateful - for _ in range(self.next_yielded): - next(self.sampler_iter) - self.next_yielded = None - elif self.samples_yielded > 0: - # don't re-create sampler_iter unless necessary, we may already have one from init - self.sampler_iter = iter(self.sampler) - self.samples_yielded = 0 - - if self.drop_last: - while True: - try: - batch = [] - for _ in range(self.batch_size): - batch.append(next(self.sampler_iter)) - self.samples_yielded += 1 - yield batch - except StopIteration: - break - else: - batch = [0] * self.batch_size - idx_in_batch = 0 - for idx in self.sampler_iter: - self.samples_yielded += 1 - batch[idx_in_batch] = idx - idx_in_batch += 1 - if idx_in_batch == self.batch_size: - yield batch - idx_in_batch = 0 - batch = [0] * self.batch_size - if idx_in_batch > 0: - yield batch[:idx_in_batch] + return _BatchSamplerIterator( + sampler=self.sampler, + batch_size=self.batch_size, + drop_last=self.drop_last, + ) class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler): From 6ba9d9437f0b256a3ce61bba3a4a00d864ff0046 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Tue, 4 Feb 2025 19:38:52 -0800 Subject: [PATCH 04/24] run precommit --- torchdata/stateful_dataloader/sampler.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index b9214bdd4..336129190 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -107,18 +107,16 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self._SAMPLER_ITER_STATE in state_dict: assert isinstance(self.sampler_iter, Stateful) self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) - if not ( - isinstance(self.sampler, Stateful) - or isinstance(self.sampler_iter, Stateful) + + if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( + self.sampler, _InfiniteConstantSampler ): # We skip x samples if underlying sampler is not stateful for _ in range(self.samples_yielded): next(self.sampler_iter) - # Skip one epoch if we were at the end of the last epoch - if hasattr(self.sampler, "__len__") and self.samples_yielded == len( - self.sampler - ): + # Skip one epoch if we were at the end of the last epoch + if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler): for _ in self.sampler_iter: pass From 1288e772c68a661882080c817cf047e6b38dc6b3 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Tue, 4 Feb 2025 21:35:53 -0800 Subject: [PATCH 05/24] remove unnecessary comment --- torchdata/stateful_dataloader/sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 
336129190..308ccd7a8 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -85,7 +85,6 @@ def __next__(self) -> list[int]: return batch except StopIteration: if self.drop_last or len(batch) == 0: - # Reset the iterator for the next epoch raise StopIteration else: return batch From 297d7bfd508dd256c1dadc97aa8ec5e8deab6d89 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Wed, 5 Feb 2025 09:03:30 -0800 Subject: [PATCH 06/24] add test for statedict before and after endofepoch --- test/stateful_dataloader/test_state_dict.py | 298 ++++++++++++++++---- 1 file changed, 240 insertions(+), 58 deletions(-) diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 15ecbe037..f8dc86e60 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -83,7 +83,9 @@ def __len__(self): return self.size -class DummyIteratorIterableDataset(torch.utils.data.IterableDataset, Iterator, Stateful): +class DummyIteratorIterableDataset( + torch.utils.data.IterableDataset, Iterator, Stateful +): def __init__(self, samples, shuffle, include_generator): self.samples = samples self.shuffle = shuffle @@ -139,7 +141,10 @@ def __iter__(self): class DummyMapDataset(torch.utils.data.Dataset): def __init__(self, size, shuffle, include_generator=True): self.size = size - self.data = [{"id": i, "strcol": f"strcol_{i}", "listcol": [i, i + 1, i + 2]} for i in range(size)] + self.data = [ + {"id": i, "strcol": f"strcol_{i}", "listcol": [i, i + 1, i + 2]} + for i in range(size) + ] self.shuffle = shuffle self.include_generator = include_generator @@ -202,7 +207,9 @@ class TestStatefulDataLoaderIterable_shard0(TestCase): def _get_dataset(self, shuffle): return DummyIterableDataset([0, 100, 37], shuffle=shuffle) - def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): + def _run_and_checkpoint( + self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False + ): dataset = self._get_dataset(shuffle) dl = StatefulDataLoader( dataset=dataset, @@ -270,7 +277,9 @@ def test_mp_pw(self): def test_mp_every_n_steps(self): batch_size = 7 for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]): - with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt): + with self.subTest( + every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt + ): self._run_and_checkpoint( num_workers=3, batch_size=batch_size, @@ -291,7 +300,9 @@ def test_random_state(self): class TestStatefulDataLoaderMap_shard1(TestStatefulDataLoaderIterable_shard0): - def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): + def _run_and_checkpoint( + self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False + ): if num_workers == 0: return dataset = DummyMapDataset(100, shuffle=shuffle) @@ -344,7 +355,9 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st class TestStatefulSampler_shard1(TestStatefulDataLoaderIterable_shard0): - def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): + def _run_and_checkpoint( + self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False + ): dataset = DummyMapDataset(100, shuffle=shuffle) sampler = DummySampler(len(dataset)) dl = StatefulDataLoader( @@ -472,7 +485,9 @@ def load_state_dict(self, state): class 
TestStatefulDataLoaderGenerator_shard2(TestStatefulDataLoaderIterable_shard0): - def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): + def _run_and_checkpoint( + self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False + ): dataset = GeneratorIterable([0, 100, 37]) dl = StatefulDataLoader( dataset=dataset, @@ -521,8 +536,12 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st self.assertEqual(batches, exp) -class TestStatefulDataLoaderGeneratorNoState_shard2(TestStatefulDataLoaderIterable_shard0): - def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): +class TestStatefulDataLoaderGeneratorNoState_shard2( + TestStatefulDataLoaderIterable_shard0 +): + def _run_and_checkpoint( + self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False + ): dataset = GeneratorIterableNoState([0, 100, 37]) dl = StatefulDataLoader( dataset=dataset, @@ -582,7 +601,9 @@ def test_generator(self): collate_fn=identity, snapshot_every_n_steps=every_n_steps, persistent_workers=pw, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) it = iter(dl) @@ -605,7 +626,9 @@ def test_iterable(self): collate_fn=identity, snapshot_every_n_steps=every_n_steps, persistent_workers=pw, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) it = iter(dl) @@ -628,7 +651,9 @@ def test_map(self): collate_fn=identity, snapshot_every_n_steps=every_n_steps, persistent_workers=pw, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) it = iter(dl) @@ -652,7 +677,9 @@ def test_map_shuffle(self): collate_fn=identity, snapshot_every_n_steps=every_n_steps, persistent_workers=pw, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) it = iter(dl) @@ -667,7 +694,9 @@ def test_map_shuffle(self): def test_map_iterrupted_shuffle(self): every_n_steps = 10 - for pw, num_workers, every_n_steps in itertools.product([False, True], [0, 2], [1, 15]): + for pw, num_workers, every_n_steps in itertools.product( + [False, True], [0, 2], [1, 15] + ): dataset = DummyMapDataset(10, shuffle=True) dl = StatefulDataLoader( dataset=dataset, @@ -676,7 +705,9 @@ def test_map_iterrupted_shuffle(self): collate_fn=identity, snapshot_every_n_steps=every_n_steps, persistent_workers=pw if num_workers > 0 else False, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) it = iter(dl) @@ -712,7 +743,9 @@ def test_generator(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) exp = list(dl) state_end = dl.state_dict() @@ -728,7 +761,9 @@ def test_generator(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + 
"forkserver" if IS_MACOS and num_workers else None + ), ) it = iter(dl) for _ in range(2): @@ -750,7 +785,9 @@ def test_generator_no_state(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) exp = list(dl) state_end = dl.state_dict() @@ -766,7 +803,9 @@ def test_generator_no_state(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) it = iter(dl) for _ in range(2): @@ -791,7 +830,9 @@ def test_iterable(self): persistent_workers=pw, batch_size=bs, generator=g, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) list(dl) state_end = dl.state_dict() @@ -806,7 +847,9 @@ def test_iterable(self): persistent_workers=pw, batch_size=bs, generator=g, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) dl.load_state_dict(state_end) batches = list(dl) @@ -828,7 +871,9 @@ def test_map(self): persistent_workers=pw, batch_size=bs, generator=generator, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) list(dl) state_end = dl.state_dict() @@ -843,7 +888,9 @@ def test_map(self): persistent_workers=pw, batch_size=bs, generator=generator, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) dl.load_state_dict(state_end) batches = list(dl) @@ -863,7 +910,9 @@ def test_map_shuffle(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) list(dl) state_end = dl.state_dict() @@ -878,7 +927,9 @@ def test_map_shuffle(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) dl.load_state_dict(state_end) batches = list(dl) @@ -896,7 +947,9 @@ def test_num_workers_mismatch(self): dataset=dataset, num_workers=initial_num_workers, collate_fn=identity, - multiprocessing_context=("forkserver" if IS_MACOS and initial_num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and initial_num_workers else None + ), ) state = dl.state_dict() @@ -908,7 +961,9 @@ def test_num_workers_mismatch(self): dataset=dataset, num_workers=num_workers, collate_fn=identity, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) dl.load_state_dict(state) try: @@ -994,7 +1049,9 @@ def test_fast_state_dict_request_skip_steps(self) -> None: class TestJsonSerDe_shard3(TestCase): def _run_test_iterable(self, num_workers): interrupt = 4 - dataset = DummyIterableDataset([0, 100, 37], 
shuffle=False, include_generator=False) + dataset = DummyIterableDataset( + [0, 100, 37], shuffle=False, include_generator=False + ) dl = StatefulDataLoader( dataset=dataset, num_workers=num_workers, @@ -1256,7 +1313,9 @@ def test_load_then_state(self): class TestStatefulDataLoaderIterable2_shard0(TestStatefulDataLoaderIterable_shard0): # Perform sanity test checks with the iterable dataset that is also an iterator def _get_dataset(self, shuffle): - return DummyIteratorIterableDataset(list(range(100)), shuffle=shuffle, include_generator=True) + return DummyIteratorIterableDataset( + list(range(100)), shuffle=shuffle, include_generator=True + ) class TestDynamicStateIterableDataset_shard0(TestCase): @@ -1274,7 +1333,9 @@ def test(self): for _ in range((num_workers + 1) * 2): next(it) state_dict = dl.state_dict() - worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["dataset_iter_state"] + worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][ + "fetcher_state" + ]["dataset_iter_state"] self.assertEqual(len(worker_state), 7) deep_copy_state_dict = deepcopy(state_dict) @@ -1284,9 +1345,9 @@ def test(self): next_state_dict = dl.state_dict() self.assertEqual(state_dict, deep_copy_state_dict) self.assertFalse(state_dict == next_state_dict) - worker_state = next_state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"][ - "dataset_iter_state" - ] + worker_state = next_state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][ + "fetcher_state" + ]["dataset_iter_state"] self.assertEqual(len(worker_state), 11) dl = StatefulDataLoader( @@ -1302,19 +1363,25 @@ def test(self): exp.extend(next(it)) state_dict = dl.state_dict() self.assertEqual(exp, [3, 3]) - worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["dataset_iter_state"] + worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][ + "fetcher_state" + ]["dataset_iter_state"] self.assertEqual(len(worker_state), 9) class TestDatasetIteratorStateDuplication_shard0(TestCase): def test(self): - dataset = DummyIteratorIterableDataset(list(range(100)), shuffle=True, include_generator=True) + dataset = DummyIteratorIterableDataset( + list(range(100)), shuffle=True, include_generator=True + ) for num_workers in (0, 2): dl = StatefulDataLoader( dataset=dataset, num_workers=num_workers, collate_fn=identity, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) it = iter(dl) # Fetch at least one batch from each worker @@ -1326,13 +1393,15 @@ def test(self): for i in range(num_workers): # Ensure worker state is stored only once if the dataset is also the iterator self.assertEqual( - state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], + state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"][ + "dataset_state" + ], None, ) self.assertTrue( - state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][ - "dataset_iter_state" - ] + state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"][ + "fetcher_state" + ]["dataset_iter_state"] ) else: self.assertEqual(state_dict["dataset_state"], None) @@ -1452,7 +1521,9 @@ def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False): num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if 
IS_MACOS and num_workers else None + ), ) def _run(self, data_size, num_workers, batch_size, shuffle=False): @@ -1488,7 +1559,9 @@ def _run(self, data_size, num_workers, batch_size, shuffle=False): epoch_num_items_yielded += 1 additional_num_items_yielded += epoch_num_items_yielded # Check that the total number of items yielded is correct - self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4) + self.assertEqual( + num_items_yielded + additional_num_items_yielded, data_size * 4 + ) # now run a second dataloder for 4 epochs and check if the order is same. dl2 = self.get_map_dl( @@ -1517,6 +1590,83 @@ def test_multiprocess_shuffle(self): self._run(100, 2, 1, True) +class TestEndOfEpochBehavior_shard0(TestCase): + def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False): + dataset = DummyMapDataset(data_size, shuffle=False) + return StatefulDataLoader( + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), + ) + + def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int: + num_items_yielded = 0 + for batch in data_loader: + num_items_yielded += 1 + return num_items_yielded + + def _run(self, data_size, num_workers, batch_size, shuffle=False): + dl = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) + # Run through the dataloader for 1 epoch and count the number of items yielded + num_items_yielded = 0 + + for batch in dl: + num_items_yielded += 1 + sd_in = dl.state_dict() + sd_out = dl.state_dict() + + self.assertEqual(num_items_yielded, data_size) + + # Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch + dl_sd_in = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) + dl_sd_in.load_state_dict(sd_in) + + # Run through the new dataloader for 1 epoch and count the number of items yielded + # num_items_yielded should be 0 since the state dict was saved before the end of epoch + num_items_yielded = self._count_items_yielded(dl_sd_in) + self.assertEqual(num_items_yielded, 0) + + # Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch + dl_sd_out = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) + dl_sd_out.load_state_dict(sd_out) + + # Run through the new dataloader for 1 epoch and count the number of items yielded + # num_items_yielded should be data_size since the state dict was saved after the end of epoch + num_items_yielded = self._count_items_yielded(dl_sd_out) + self.assertEqual(num_items_yielded, data_size) + + def test_main_process(self): + self._run(100, 0, 1, False) + + def test_multiprocess(self): + self._run(100, 2, 1, False) + + def test_main_process_shuffle(self): + self._run(100, 0, 1, True) + + def test_multiprocess_shuffle(self): + self._run(100, 2, 1, True) + + class TestMultiEpochState_shard0(TestCase): def get_iterable_dl(self, pw, num_workers): data_size = [25, 50, 100, 75] @@ -1528,7 +1678,9 @@ def get_iterable_dl(self, pw, num_workers): num_workers=num_workers, persistent_workers=pw, collate_fn=identity, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) def _run(self, pw: bool, num_workers: int): @@ 
-1589,7 +1741,9 @@ def __iter__(self): num_workers = torch.utils.data.get_worker_info().num_workers num_samples = (int)(self.length / num_workers) - self.iter_state = IterationState(num_samples * worker_id, num_samples * (worker_id + 1)) + self.iter_state = IterationState( + num_samples * worker_id, num_samples * (worker_id + 1) + ) return self def __next__(self): @@ -1615,29 +1769,39 @@ def _get_iter_calls(self, state): if w_states[0]["dataset_state"] is not None: return [x["dataset_state"]["iter_calls"] for x in w_states] - return [x["fetcher_state"]["dataset_iter_state"]["iter_calls"] for x in w_states] + return [ + x["fetcher_state"]["dataset_iter_state"]["iter_calls"] for x in w_states + ] def _run_test(self, num_workers, dataset, expected_iter_calls): dl = StatefulDataLoader( dataset=dataset, num_workers=num_workers, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) iter(dl) state = dl.state_dict() # Ensure iter is called only once per worker - self.assertEqual(self._get_iter_calls(state), [expected_iter_calls[0]] * max(1, num_workers)) + self.assertEqual( + self._get_iter_calls(state), [expected_iter_calls[0]] * max(1, num_workers) + ) dl2 = StatefulDataLoader( dataset=dataset, num_workers=num_workers, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) dl2.load_state_dict(state) iter(dl2) state2 = dl2.state_dict() # Ensure that iter is called only once per worker even when dataloader resumes from a state - self.assertEqual(self._get_iter_calls(state2), [expected_iter_calls[1]] * max(1, num_workers)) + self.assertEqual( + self._get_iter_calls(state2), [expected_iter_calls[1]] * max(1, num_workers) + ) def test_inline(self): self._run_test(0, CountIterCalls(100), [1, 2]) @@ -1678,7 +1842,9 @@ def _run_test(self, num_workers, dataset): dataset=dataset, num_workers=num_workers, collate_fn=identity, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) it = iter(dl) data = [] @@ -1692,7 +1858,9 @@ def _run_test(self, num_workers, dataset): dataset=dataset, num_workers=num_workers, collate_fn=identity, - multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + multiprocessing_context=( + "forkserver" if IS_MACOS and num_workers else None + ), ) dl2.load_state_dict(state) it = iter(dl2) @@ -1739,7 +1907,9 @@ def give_data(self, iter_start, iter_end): def __iter__(self): worker_info = torch.utils.data.get_worker_info() - per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) + per_worker = int( + math.ceil((self.end - self.start) / float(worker_info.num_workers)) + ) worker_id = worker_info.id iter_start = self.start + worker_id * per_worker iter_end = min(iter_start + per_worker, self.end) @@ -1793,12 +1963,18 @@ def test_out_of_order_iterable_ds_one_completed_worker(self): state_dict = dataloader.state_dict() break - worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"] - worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"] + worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][ + "fetcher_state" + ]["fetcher_ended"] + worker_1_ended = 
state_dict["_snapshot"]["_worker_snapshots"]["worker_1"][ + "fetcher_state" + ]["fetcher_ended"] self.assertTrue(worker_0_ended) self.assertFalse(worker_1_ended) - new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False) + new_dataloader = StatefulDataLoader( + dataset, batch_size=1, num_workers=2, in_order=False + ) new_dataloader.load_state_dict(state_dict) for i, data in enumerate(new_dataloader): output.append(data) @@ -1824,12 +2000,18 @@ def test_out_of_order_iterable_ds_no_completed_workers(self): state_dict = dataloader.state_dict() break - worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"] - worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"] + worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][ + "fetcher_state" + ]["fetcher_ended"] + worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"][ + "fetcher_state" + ]["fetcher_ended"] self.assertFalse(worker_0_ended) self.assertFalse(worker_1_ended) - new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False) + new_dataloader = StatefulDataLoader( + dataset, batch_size=1, num_workers=2, in_order=False + ) new_dataloader.load_state_dict(state_dict) for i, data in enumerate(new_dataloader): output.append(data) From 56c6882bc154276caef4cee3f283ac054e703b19 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Wed, 5 Feb 2025 09:07:44 -0800 Subject: [PATCH 07/24] run precommit --- test/stateful_dataloader/test_state_dict.py | 225 +++++--------------- 1 file changed, 59 insertions(+), 166 deletions(-) diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index f8dc86e60..9d90f381a 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -83,9 +83,7 @@ def __len__(self): return self.size -class DummyIteratorIterableDataset( - torch.utils.data.IterableDataset, Iterator, Stateful -): +class DummyIteratorIterableDataset(torch.utils.data.IterableDataset, Iterator, Stateful): def __init__(self, samples, shuffle, include_generator): self.samples = samples self.shuffle = shuffle @@ -141,10 +139,7 @@ def __iter__(self): class DummyMapDataset(torch.utils.data.Dataset): def __init__(self, size, shuffle, include_generator=True): self.size = size - self.data = [ - {"id": i, "strcol": f"strcol_{i}", "listcol": [i, i + 1, i + 2]} - for i in range(size) - ] + self.data = [{"id": i, "strcol": f"strcol_{i}", "listcol": [i, i + 1, i + 2]} for i in range(size)] self.shuffle = shuffle self.include_generator = include_generator @@ -207,9 +202,7 @@ class TestStatefulDataLoaderIterable_shard0(TestCase): def _get_dataset(self, shuffle): return DummyIterableDataset([0, 100, 37], shuffle=shuffle) - def _run_and_checkpoint( - self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False - ): + def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): dataset = self._get_dataset(shuffle) dl = StatefulDataLoader( dataset=dataset, @@ -277,9 +270,7 @@ def test_mp_pw(self): def test_mp_every_n_steps(self): batch_size = 7 for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]): - with self.subTest( - every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt - ): + with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt): 
self._run_and_checkpoint( num_workers=3, batch_size=batch_size, @@ -300,9 +291,7 @@ def test_random_state(self): class TestStatefulDataLoaderMap_shard1(TestStatefulDataLoaderIterable_shard0): - def _run_and_checkpoint( - self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False - ): + def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): if num_workers == 0: return dataset = DummyMapDataset(100, shuffle=shuffle) @@ -355,9 +344,7 @@ def _run_and_checkpoint( class TestStatefulSampler_shard1(TestStatefulDataLoaderIterable_shard0): - def _run_and_checkpoint( - self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False - ): + def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): dataset = DummyMapDataset(100, shuffle=shuffle) sampler = DummySampler(len(dataset)) dl = StatefulDataLoader( @@ -485,9 +472,7 @@ def load_state_dict(self, state): class TestStatefulDataLoaderGenerator_shard2(TestStatefulDataLoaderIterable_shard0): - def _run_and_checkpoint( - self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False - ): + def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): dataset = GeneratorIterable([0, 100, 37]) dl = StatefulDataLoader( dataset=dataset, @@ -536,12 +521,8 @@ def _run_and_checkpoint( self.assertEqual(batches, exp) -class TestStatefulDataLoaderGeneratorNoState_shard2( - TestStatefulDataLoaderIterable_shard0 -): - def _run_and_checkpoint( - self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False - ): +class TestStatefulDataLoaderGeneratorNoState_shard2(TestStatefulDataLoaderIterable_shard0): + def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False): dataset = GeneratorIterableNoState([0, 100, 37]) dl = StatefulDataLoader( dataset=dataset, @@ -601,9 +582,7 @@ def test_generator(self): collate_fn=identity, snapshot_every_n_steps=every_n_steps, persistent_workers=pw, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) @@ -626,9 +605,7 @@ def test_iterable(self): collate_fn=identity, snapshot_every_n_steps=every_n_steps, persistent_workers=pw, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) @@ -651,9 +628,7 @@ def test_map(self): collate_fn=identity, snapshot_every_n_steps=every_n_steps, persistent_workers=pw, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) @@ -677,9 +652,7 @@ def test_map_shuffle(self): collate_fn=identity, snapshot_every_n_steps=every_n_steps, persistent_workers=pw, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) @@ -694,9 +667,7 @@ def test_map_shuffle(self): def test_map_iterrupted_shuffle(self): every_n_steps = 10 - for pw, num_workers, every_n_steps in itertools.product( - [False, True], [0, 2], [1, 15] - ): + for pw, num_workers, every_n_steps in itertools.product([False, True], [0, 2], [1, 15]): dataset = DummyMapDataset(10, shuffle=True) dl = StatefulDataLoader( 
dataset=dataset, @@ -705,9 +676,7 @@ def test_map_iterrupted_shuffle(self): collate_fn=identity, snapshot_every_n_steps=every_n_steps, persistent_workers=pw if num_workers > 0 else False, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) @@ -743,9 +712,7 @@ def test_generator(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) exp = list(dl) state_end = dl.state_dict() @@ -761,9 +728,7 @@ def test_generator(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) for _ in range(2): @@ -785,9 +750,7 @@ def test_generator_no_state(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) exp = list(dl) state_end = dl.state_dict() @@ -803,9 +766,7 @@ def test_generator_no_state(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) for _ in range(2): @@ -830,9 +791,7 @@ def test_iterable(self): persistent_workers=pw, batch_size=bs, generator=g, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) list(dl) state_end = dl.state_dict() @@ -847,9 +806,7 @@ def test_iterable(self): persistent_workers=pw, batch_size=bs, generator=g, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) dl.load_state_dict(state_end) batches = list(dl) @@ -871,9 +828,7 @@ def test_map(self): persistent_workers=pw, batch_size=bs, generator=generator, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) list(dl) state_end = dl.state_dict() @@ -888,9 +843,7 @@ def test_map(self): persistent_workers=pw, batch_size=bs, generator=generator, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) dl.load_state_dict(state_end) batches = list(dl) @@ -910,9 +863,7 @@ def test_map_shuffle(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) list(dl) state_end = dl.state_dict() @@ -927,9 +878,7 @@ def test_map_shuffle(self): snapshot_every_n_steps=every_n_steps, persistent_workers=pw, batch_size=bs, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) 
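        # Resume the newly constructed loader from the snapshot taken after the first loader was exhausted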
dl.load_state_dict(state_end) batches = list(dl) @@ -947,9 +896,7 @@ def test_num_workers_mismatch(self): dataset=dataset, num_workers=initial_num_workers, collate_fn=identity, - multiprocessing_context=( - "forkserver" if IS_MACOS and initial_num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and initial_num_workers else None), ) state = dl.state_dict() @@ -961,9 +908,7 @@ def test_num_workers_mismatch(self): dataset=dataset, num_workers=num_workers, collate_fn=identity, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) dl.load_state_dict(state) try: @@ -1049,9 +994,7 @@ def test_fast_state_dict_request_skip_steps(self) -> None: class TestJsonSerDe_shard3(TestCase): def _run_test_iterable(self, num_workers): interrupt = 4 - dataset = DummyIterableDataset( - [0, 100, 37], shuffle=False, include_generator=False - ) + dataset = DummyIterableDataset([0, 100, 37], shuffle=False, include_generator=False) dl = StatefulDataLoader( dataset=dataset, num_workers=num_workers, @@ -1313,9 +1256,7 @@ def test_load_then_state(self): class TestStatefulDataLoaderIterable2_shard0(TestStatefulDataLoaderIterable_shard0): # Perform sanity test checks with the iterable dataset that is also an iterator def _get_dataset(self, shuffle): - return DummyIteratorIterableDataset( - list(range(100)), shuffle=shuffle, include_generator=True - ) + return DummyIteratorIterableDataset(list(range(100)), shuffle=shuffle, include_generator=True) class TestDynamicStateIterableDataset_shard0(TestCase): @@ -1333,9 +1274,7 @@ def test(self): for _ in range((num_workers + 1) * 2): next(it) state_dict = dl.state_dict() - worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][ - "fetcher_state" - ]["dataset_iter_state"] + worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["dataset_iter_state"] self.assertEqual(len(worker_state), 7) deep_copy_state_dict = deepcopy(state_dict) @@ -1345,9 +1284,9 @@ def test(self): next_state_dict = dl.state_dict() self.assertEqual(state_dict, deep_copy_state_dict) self.assertFalse(state_dict == next_state_dict) - worker_state = next_state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][ - "fetcher_state" - ]["dataset_iter_state"] + worker_state = next_state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"][ + "dataset_iter_state" + ] self.assertEqual(len(worker_state), 11) dl = StatefulDataLoader( @@ -1363,25 +1302,19 @@ def test(self): exp.extend(next(it)) state_dict = dl.state_dict() self.assertEqual(exp, [3, 3]) - worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][ - "fetcher_state" - ]["dataset_iter_state"] + worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["dataset_iter_state"] self.assertEqual(len(worker_state), 9) class TestDatasetIteratorStateDuplication_shard0(TestCase): def test(self): - dataset = DummyIteratorIterableDataset( - list(range(100)), shuffle=True, include_generator=True - ) + dataset = DummyIteratorIterableDataset(list(range(100)), shuffle=True, include_generator=True) for num_workers in (0, 2): dl = StatefulDataLoader( dataset=dataset, num_workers=num_workers, collate_fn=identity, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) # Fetch at least one batch from each 
worker @@ -1393,15 +1326,13 @@ def test(self): for i in range(num_workers): # Ensure worker state is stored only once if the dataset is also the iterator self.assertEqual( - state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"][ - "dataset_state" - ], + state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None, ) self.assertTrue( - state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"][ - "fetcher_state" - ]["dataset_iter_state"] + state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][ + "dataset_iter_state" + ] ) else: self.assertEqual(state_dict["dataset_state"], None) @@ -1521,9 +1452,7 @@ def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False): num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) def _run(self, data_size, num_workers, batch_size, shuffle=False): @@ -1559,9 +1488,7 @@ def _run(self, data_size, num_workers, batch_size, shuffle=False): epoch_num_items_yielded += 1 additional_num_items_yielded += epoch_num_items_yielded # Check that the total number of items yielded is correct - self.assertEqual( - num_items_yielded + additional_num_items_yielded, data_size * 4 - ) + self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4) # now run a second dataloder for 4 epochs and check if the order is same. dl2 = self.get_map_dl( @@ -1598,9 +1525,7 @@ def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False): num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int: @@ -1678,9 +1603,7 @@ def get_iterable_dl(self, pw, num_workers): num_workers=num_workers, persistent_workers=pw, collate_fn=identity, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) def _run(self, pw: bool, num_workers: int): @@ -1741,9 +1664,7 @@ def __iter__(self): num_workers = torch.utils.data.get_worker_info().num_workers num_samples = (int)(self.length / num_workers) - self.iter_state = IterationState( - num_samples * worker_id, num_samples * (worker_id + 1) - ) + self.iter_state = IterationState(num_samples * worker_id, num_samples * (worker_id + 1)) return self def __next__(self): @@ -1769,39 +1690,29 @@ def _get_iter_calls(self, state): if w_states[0]["dataset_state"] is not None: return [x["dataset_state"]["iter_calls"] for x in w_states] - return [ - x["fetcher_state"]["dataset_iter_state"]["iter_calls"] for x in w_states - ] + return [x["fetcher_state"]["dataset_iter_state"]["iter_calls"] for x in w_states] def _run_test(self, num_workers, dataset, expected_iter_calls): dl = StatefulDataLoader( dataset=dataset, num_workers=num_workers, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) iter(dl) state = dl.state_dict() # Ensure iter is called only once per worker - self.assertEqual( - self._get_iter_calls(state), [expected_iter_calls[0]] * max(1, num_workers) - ) + self.assertEqual(self._get_iter_calls(state), 
[expected_iter_calls[0]] * max(1, num_workers)) dl2 = StatefulDataLoader( dataset=dataset, num_workers=num_workers, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) dl2.load_state_dict(state) iter(dl2) state2 = dl2.state_dict() # Ensure that iter is called only once per worker even when dataloader resumes from a state - self.assertEqual( - self._get_iter_calls(state2), [expected_iter_calls[1]] * max(1, num_workers) - ) + self.assertEqual(self._get_iter_calls(state2), [expected_iter_calls[1]] * max(1, num_workers)) def test_inline(self): self._run_test(0, CountIterCalls(100), [1, 2]) @@ -1842,9 +1753,7 @@ def _run_test(self, num_workers, dataset): dataset=dataset, num_workers=num_workers, collate_fn=identity, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) data = [] @@ -1858,9 +1767,7 @@ def _run_test(self, num_workers, dataset): dataset=dataset, num_workers=num_workers, collate_fn=identity, - multiprocessing_context=( - "forkserver" if IS_MACOS and num_workers else None - ), + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) dl2.load_state_dict(state) it = iter(dl2) @@ -1907,9 +1814,7 @@ def give_data(self, iter_start, iter_end): def __iter__(self): worker_info = torch.utils.data.get_worker_info() - per_worker = int( - math.ceil((self.end - self.start) / float(worker_info.num_workers)) - ) + per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) worker_id = worker_info.id iter_start = self.start + worker_id * per_worker iter_end = min(iter_start + per_worker, self.end) @@ -1963,18 +1868,12 @@ def test_out_of_order_iterable_ds_one_completed_worker(self): state_dict = dataloader.state_dict() break - worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][ - "fetcher_state" - ]["fetcher_ended"] - worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"][ - "fetcher_state" - ]["fetcher_ended"] + worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"] + worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"] self.assertTrue(worker_0_ended) self.assertFalse(worker_1_ended) - new_dataloader = StatefulDataLoader( - dataset, batch_size=1, num_workers=2, in_order=False - ) + new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False) new_dataloader.load_state_dict(state_dict) for i, data in enumerate(new_dataloader): output.append(data) @@ -2000,18 +1899,12 @@ def test_out_of_order_iterable_ds_no_completed_workers(self): state_dict = dataloader.state_dict() break - worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][ - "fetcher_state" - ]["fetcher_ended"] - worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"][ - "fetcher_state" - ]["fetcher_ended"] + worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"] + worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"] self.assertFalse(worker_0_ended) self.assertFalse(worker_1_ended) - new_dataloader = StatefulDataLoader( - dataset, batch_size=1, num_workers=2, in_order=False - ) + new_dataloader = StatefulDataLoader(dataset, 
batch_size=1, num_workers=2, in_order=False) new_dataloader.load_state_dict(state_dict) for i, data in enumerate(new_dataloader): output.append(data) From 6f3abf6ac9256c1fa2e601f4387d4a1d8c9ab405 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Wed, 5 Feb 2025 12:34:49 -0800 Subject: [PATCH 08/24] check if _sampler_iter is exhausted --- torchdata/stateful_dataloader/sampler.py | 17 +- .../stateful_dataloader.py | 249 +++++++++++++----- 2 files changed, 199 insertions(+), 67 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 308ccd7a8..9172b3e6d 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -107,17 +107,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: assert isinstance(self.sampler_iter, Stateful) self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) - if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( - self.sampler, _InfiniteConstantSampler - ): + if not ( + isinstance(self.sampler, Stateful) + or isinstance(self.sampler_iter, Stateful) + ) and not isinstance(self.sampler, _InfiniteConstantSampler): # We skip x samples if underlying sampler is not stateful for _ in range(self.samples_yielded): next(self.sampler_iter) - - # Skip one epoch if we were at the end of the last epoch - if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler): - for _ in self.sampler_iter: - pass + # elif self.samples_yielded > 0: + # print("no fast forward, reset") + # # don't re-create sampler_iter unless necessary, we may already have one from init + # self.sampler_iter = iter(self.sampler) + # self.samples_yielded = 0 class BatchSampler(torch.utils.data.sampler.BatchSampler): diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py index 078b378ee..f43e23dd7 100644 --- a/torchdata/stateful_dataloader/stateful_dataloader.py +++ b/torchdata/stateful_dataloader/stateful_dataloader.py @@ -46,7 +46,10 @@ ) from torch.utils.data.dataloader import _BaseDataLoaderIter, _InfiniteConstantSampler -from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper +from torch.utils.data.datapipes.datapipe import ( + _IterDataPipeSerializationWrapper, + _MapDataPipeSerializationWrapper, +) from .incremental_state import ( _DATASET_ITER_STATE, @@ -215,7 +218,8 @@ def __init__( if num_workers < 0: raise ValueError( - "num_workers option should be non-negative; " "use num_workers=0 to disable multiprocessing." + "num_workers option should be non-negative; " + "use num_workers=0 to disable multiprocessing." ) if timeout < 0: @@ -291,7 +295,9 @@ def __init__( # specific workers. if isinstance(dataset, IterDataPipe): if shuffle is not None: - dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) + dataset = torch.utils.data.graph_settings.apply_shuffle_settings( + dataset, shuffle=shuffle + ) # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default. 
elif shuffle not in {False, None}: raise ValueError( @@ -320,7 +326,9 @@ def __init__( # auto_collation with custom batch_sampler if batch_size != 1 or shuffle or sampler is not None or drop_last: raise ValueError( - "batch_sampler option is mutually exclusive " "with batch_size, shuffle, sampler, and " "drop_last" + "batch_sampler option is mutually exclusive " + "with batch_size, shuffle, sampler, and " + "drop_last" ) batch_size = None drop_last = False @@ -328,7 +336,8 @@ def __init__( # no auto_collation if drop_last: raise ValueError( - "batch_size=None option disables auto-batching " "and is mutually exclusive with drop_last" + "batch_size=None option disables auto-batching " + "and is mutually exclusive with drop_last" ) if sampler is None: # give default samplers @@ -362,7 +371,9 @@ def __init__( # set DataLoader's __initialized attribute. self._DataLoader__initialized = True - self._IterableDataset_len_called = None # See NOTE [ IterableDataset and __len__ ] + self._IterableDataset_len_called = ( + None # See NOTE [ IterableDataset and __len__ ] + ) self._iterator = None @@ -473,13 +484,19 @@ def __init__(self, loader, next_iter_state=None): # Taking care of distributed sharding if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): # For BC, use default SHARDING_PRIORITIES - torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank) + torch.utils.data.graph_settings.apply_sharding( + self._dataset, self._world_size, self._rank + ) if next_iter_state is not None: self.load_state_dict(next_iter_state) else: self._dataset_fetcher = _DatasetKind.create_fetcher( - self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, ) def _next_data(self): @@ -492,7 +509,9 @@ def _next_data(self): def state_dict(self): if self._dataset_kind == _DatasetKind.Iterable: fetcher_state = { - _DATASET_ITER_STATE: try_to_serialize(self._dataset_fetcher.dataset_iter), + _DATASET_ITER_STATE: try_to_serialize( + self._dataset_fetcher.dataset_iter + ), _FETCHER_ENDED: self._dataset_fetcher.ended, } dataset_state = None @@ -522,15 +541,31 @@ def load_state_dict(self, state_dict): self._sampler_iter_yielded = state_dict[_SAMPLER_ITER_YIELDED] # Try to restore from either _index_sampler state_dict or _sampler_iter state_dict - if isinstance(self._index_sampler, Stateful) or isinstance(self._sampler_iter, Stateful): - self._index_sampler = try_to_deserialize(self._index_sampler, state_dict[_INDEX_SAMPLER_STATE]) + if isinstance(self._index_sampler, Stateful) or isinstance( + self._sampler_iter, Stateful + ): + self._index_sampler = try_to_deserialize( + self._index_sampler, state_dict[_INDEX_SAMPLER_STATE] + ) self._sampler_iter = iter(self._index_sampler) if state_dict[_SAMPLER_ITER_STATE] is not None: - self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]) + self._sampler_iter = try_to_deserialize( + self._sampler_iter, state_dict[_SAMPLER_ITER_STATE] + ) + if state_dict[_ITERATOR_FINISHED]: + try: + next(self._sampler_iter) + except StopIteration: + pass else: - if not isinstance(self._index_sampler, torch.utils.data.dataloader._InfiniteConstantSampler): + if not isinstance( + self._index_sampler, + torch.utils.data.dataloader._InfiniteConstantSampler, + ): # Fallback to fastforward - self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None) + self._sampler_iter = 
itertools.islice( + self._index_sampler, self._sampler_iter_yielded, None + ) self._num_yielded = state_dict[self._NUM_YIELDED] self._IterableDataset_len_called = state_dict[_ITERABLEDATASET_LEN_CALLED] self._shared_seed = state_dict[_SHARED_SEED] @@ -539,21 +574,33 @@ def load_state_dict(self, state_dict): # 1. try to restore dataset state # 2. generate dataset iterator # 3. try to restore iterator state - if state_dict[_DATASET_STATE] is not None and isinstance(self._dataset, Stateful): - self._dataset = try_to_deserialize(self._dataset, state_dict[_DATASET_STATE]) + if state_dict[_DATASET_STATE] is not None and isinstance( + self._dataset, Stateful + ): + self._dataset = try_to_deserialize( + self._dataset, state_dict[_DATASET_STATE] + ) self._dataset_fetcher = _DatasetKind.create_fetcher( - self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, ) if self._dataset_kind == _DatasetKind.Iterable: # If either dataset or it's iter is stateful, we don't fast-forward - if isinstance(self._dataset, Stateful) or isinstance(self._dataset_fetcher.dataset_iter, Stateful): + if isinstance(self._dataset, Stateful) or isinstance( + self._dataset_fetcher.dataset_iter, Stateful + ): if state_dict[_FETCHER_STATE] is not None: if state_dict[_FETCHER_STATE][_DATASET_ITER_STATE] is not None: self._dataset_fetcher.dataset_iter = try_to_deserialize( self._dataset_fetcher.dataset_iter, state_dict[_FETCHER_STATE][_DATASET_ITER_STATE], ) - self._dataset_fetcher.ended = state_dict[_FETCHER_STATE][_FETCHER_ENDED] + self._dataset_fetcher.ended = state_dict[_FETCHER_STATE][ + _FETCHER_ENDED + ] else: # No state, just try to fastforward if self._num_yielded > 0: @@ -907,7 +954,10 @@ def __init__(self, loader, next_iter_state): # Additional worker init function will take care of sharding in MP and Distributed if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): self._worker_init_fn = functools.partial( - _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank + _sharding_worker_init_fn, + self._worker_init_fn, + self._world_size, + self._rank, ) # No certainty which module multiprocessing_context is @@ -925,16 +975,20 @@ def __init__(self, loader, next_iter_state): self._SNAPSHOT in next_iter_state ), f"State doesn't contain key '{self._SNAPSHOT}' expected for multiprocess dataloader" wstates = next_iter_state[self._SNAPSHOT].get(self._WORKER_SNAPSHOTS, {}) - assert set(map(self._worker_key, range(len(wstates)))) == set(wstates.keys()), ( + assert set(map(self._worker_key, range(len(wstates)))) == set( + wstates.keys() + ), ( len(wstates), wstates.keys(), ) for worker_key, sd in wstates.items(): worker_states[worker_key] = sd - self._base_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get(self._BASE_SEED, self._base_seed) - self._shared_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get( - _SHARED_SEED, self._shared_seed + self._base_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get( + self._BASE_SEED, self._base_seed ) + self._shared_seed = next_iter_state[self._SNAPSHOT][ + self._MAIN_SNAPSHOT + ].get(_SHARED_SEED, self._shared_seed) for i in range(self._num_workers): # No certainty which module multiprocessing_context is @@ -982,7 +1036,9 @@ def __init__(self, loader, next_iter_state): if self._pin_memory_device == "xpu": current_device = torch.xpu.current_device() # type: ignore[attr-defined] elif 
self._pin_memory_device == torch._C._get_privateuse1_backend_name(): - custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) + custom_device_mod = getattr( + torch, torch._C._get_privateuse1_backend_name() + ) current_device = custom_device_mod.current_device() else: current_device = torch.cuda.current_device() # choose cuda for default @@ -1015,7 +1071,9 @@ def __init__(self, loader, next_iter_state): import atexit for w in self._workers: - atexit.register(_StatefulMultiProcessingDataLoaderIter._clean_up_worker, w) + atexit.register( + _StatefulMultiProcessingDataLoaderIter._clean_up_worker, w + ) # .pid can be None only before process is spawned (not the case, so ignore) _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] @@ -1031,17 +1089,23 @@ def __init__(self, loader, next_iter_state): # We need to send initial worker state back to the main process to handle state_dict() requests # before n >= num_workers steps are taken. # self._worker_snapshots: Dict[str, _IncrementalWorkerState] = {} - self._worker_snapshots = {key: _IncrementalWorkerState(state) for key, state in worker_states.items()} + self._worker_snapshots = { + key: _IncrementalWorkerState(state) for key, state in worker_states.items() + } self._reset(loader, first_iter=True, prime_prefetch=next_iter_state is None) # Try to restore main state if next_iter_state is not None: - self._restore_main_state(next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT]) + self._restore_main_state( + next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT] + ) self._num_yielded = next_iter_state[self._SNAPSHOT][self._SNAPSHOT_STEP] self._update_snapshot( snapshot_step=next_iter_state[self._SNAPSHOT][self._SNAPSHOT_STEP], - last_yielded_worker_id=next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID], + last_yielded_worker_id=next_iter_state[self._SNAPSHOT][ + self._LAST_YIELDED_WORKER_ID + ], num_workers=self._num_workers, main_snapshot=next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT], worker_snapshots=self._worker_snapshots, @@ -1052,7 +1116,10 @@ def __init__(self, loader, next_iter_state): for state in worker_states.values(): if state is None: continue - if state[_DATASET_STATE] is None and state[_FETCHER_STATE][_DATASET_ITER_STATE] is None: + if ( + state[_DATASET_STATE] is None + and state[_FETCHER_STATE][_DATASET_ITER_STATE] is None + ): fast_forward = True break @@ -1069,10 +1136,17 @@ def __init__(self, loader, next_iter_state): for _ in range(self._num_yielded): next(self) # Check if last_yielded_worker_id matches - if self._last_yielded_worker_id != next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID]: - raise ValueError("last_yielded_worker_id does not match, the dataset may have changed") + if ( + self._last_yielded_worker_id + != next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID] + ): + raise ValueError( + "last_yielded_worker_id does not match, the dataset may have changed" + ) else: - self._last_yielded_worker_id = next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID] + self._last_yielded_worker_id = next_iter_state[self._SNAPSHOT][ + self._LAST_YIELDED_WORKER_ID + ] for _ in range(self._last_yielded_worker_id + 1): next(self._worker_queue_idx_cycle) for _ in range(self._prefetch_factor * self._num_workers): @@ -1090,7 +1164,9 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True): # map: task idx => - (worker_id,) if data isn't fetched (outstanding) # \ (worker_id, data) if data is already fetched 
(out-of-order) self._task_info = {} - self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1) + self._tasks_outstanding = ( + 0 # always equal to count(v for v in task_info.values() if len(v) == 1) + ) # A list of booleans representing whether each worker still has work to # do, i.e., not having exhausted its iterable dataset object. It always # contains all `True`s if not using an iterable-style dataset @@ -1115,7 +1191,9 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True): while remaining > 0: _, data = self._get_data() if not all(self._workers_status): - raise ValueError(f"A worker has failed during startup! {self._workers_status}") + raise ValueError( + f"A worker has failed during startup! {self._workers_status}" + ) elif isinstance(data, _AckStartup): if isinstance(data.initial_state, ExceptionWrapper): data.initial_state.reraise() @@ -1123,27 +1201,37 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True): if data.is_delta: self._worker_snapshots[self._worker_key(data.worker_id)].apply_delta(data.initial_state) # type: ignore[arg-type] else: - self._worker_snapshots[self._worker_key(data.worker_id)] = _IncrementalWorkerState( + self._worker_snapshots[ + self._worker_key(data.worker_id) + ] = _IncrementalWorkerState( data.initial_state # type: ignore[arg-type] ) remaining -= 1 else: - raise ValueError(f"Invalid response from worker after startup: {data}") + raise ValueError( + f"Invalid response from worker after startup: {data}" + ) else: # We resume the prefetching in case it was enabled for idx in range(self._num_workers): - self._index_queues[idx].put(_utils.worker._ResumeIteration(self._shared_seed)) + self._index_queues[idx].put( + _utils.worker._ResumeIteration(self._shared_seed) + ) resume_iteration_cnt = self._num_workers while resume_iteration_cnt > 0: return_idx, data = self._get_data() if not all(self._workers_status): - raise ValueError(f"A worker has failed during Resume! {self._workers_status}") + raise ValueError( + f"A worker has failed during Resume! 
{self._workers_status}" + ) if isinstance(return_idx, _utils.worker._ResumeIteration): assert isinstance(data, _AckStartup), (return_idx, data) if isinstance(data.initial_state, ExceptionWrapper): data.initial_state.reraise() assert data.initial_state is not None, data - self._worker_snapshots[self._worker_key(data.worker_id)] = _IncrementalWorkerState( + self._worker_snapshots[ + self._worker_key(data.worker_id) + ] = _IncrementalWorkerState( data.initial_state # type: ignore[arg-type] ) resume_iteration_cnt -= 1 @@ -1211,7 +1299,9 @@ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): self._mark_worker_as_unavailable(worker_id) if len(failed_workers) > 0: pids_str = ", ".join(str(w.pid) for w in failed_workers) - raise RuntimeError(f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly") from e + raise RuntimeError( + f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly" + ) from e if isinstance(e, queue.Empty): return (False, None) import errno @@ -1349,7 +1439,9 @@ def _get_data(self): if success: return data else: - raise RuntimeError(f"DataLoader timed out after {self._timeout} seconds") + raise RuntimeError( + f"DataLoader timed out after {self._timeout} seconds" + ) elif self._pin_memory: while self._pin_memory_thread.is_alive(): success, data = self._try_get_data() @@ -1382,7 +1474,9 @@ def _next_data(self): info = self._task_info.get(self._rcvd_idx, None) if info: worker_id = info[0] - if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active + if ( + len(info) == 2 or self._workers_status[worker_id] + ): # has data or is still active break del self._task_info[self._rcvd_idx] self._rcvd_idx += 1 @@ -1398,7 +1492,9 @@ def _next_data(self): if len(self._task_info[self._rcvd_idx]) == 2: data, worker_id, state_dict = self._task_info.pop(self._rcvd_idx)[1] if isinstance(data, _utils.worker._IterableDatasetStopIteration): - self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict) + self._update_worker_snapshot( + self._worker_key(data.worker_id), state_dict + ) self._rcvd_idx += 1 continue else: @@ -1415,7 +1511,9 @@ def _next_data(self): self._workers_status[data.worker_id] = False else: self._mark_worker_as_unavailable(data.worker_id) - assert state_dict is not None, "StopIteration should always be accompanied by a state_dict" + assert ( + state_dict is not None + ), "StopIteration should always be accompanied by a state_dict" self._try_put_index() # We want to process states until we get to that position # in the worker cycle, therefore if out-of-order we want @@ -1426,7 +1524,9 @@ def _next_data(self): if not self._in_order: # don't store it for later, process now if isinstance(data, _utils.worker._IterableDatasetStopIteration): - self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict) + self._update_worker_snapshot( + self._worker_key(data.worker_id), state_dict + ) continue del self._task_info[idx] return self._process_data(data, worker_id, state_dict) @@ -1434,7 +1534,9 @@ def _next_data(self): else: del self._task_info[idx] if isinstance(data, _utils.worker._IterableDatasetStopIteration): - self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict) + self._update_worker_snapshot( + self._worker_key(data.worker_id), state_dict + ) self._rcvd_idx += 1 continue else: @@ -1456,15 +1558,26 @@ def _restore_main_state(self, state_dict): assert self._num_workers == state_dict[self._NUM_WORKERS] # Try to restore from either _index_sampler state_dict or _sampler_iter state_dict 
self._sampler_iter_yielded = state_dict[_SAMPLER_ITER_YIELDED] - if isinstance(self._index_sampler, Stateful) or isinstance(self._sampler_iter, Stateful): - self._index_sampler = try_to_deserialize(self._index_sampler, state_dict[_INDEX_SAMPLER_STATE]) + if isinstance(self._index_sampler, Stateful) or isinstance( + self._sampler_iter, Stateful + ): + self._index_sampler = try_to_deserialize( + self._index_sampler, state_dict[_INDEX_SAMPLER_STATE] + ) self._sampler_iter = iter(self._index_sampler) if state_dict[_SAMPLER_ITER_STATE] is not None: - self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]) + self._sampler_iter = try_to_deserialize( + self._sampler_iter, state_dict[_SAMPLER_ITER_STATE] + ) else: - if not isinstance(self._index_sampler, torch.utils.data.dataloader._InfiniteConstantSampler): + if not isinstance( + self._index_sampler, + torch.utils.data.dataloader._InfiniteConstantSampler, + ): # Fallback to fastforward - self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None) + self._sampler_iter = itertools.islice( + self._index_sampler, self._sampler_iter_yielded, None + ) self._IterableDataset_len_called = state_dict[_ITERABLEDATASET_LEN_CALLED] self._shared_seed = state_dict[_SHARED_SEED] self._base_seed = state_dict[self._BASE_SEED] @@ -1501,7 +1614,9 @@ def _try_put_index(self): if self._workers_status[worker_queue_idx]: if self._in_order: break - elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(self._workers_status): + elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum( + self._workers_status + ): # when self._in_order is False, distribute work to a worker if it has capacity # _workers_status is updated only in this thread, so the sum is guaranteed > 0 break @@ -1527,20 +1642,29 @@ def _process_data(self, data, worker_id, state_dict): self._last_yielded_worker_id = worker_id # Update latest worker state if state_dict is not None: - self._update_worker_snapshot(self._worker_key(state_dict[_WORKER_ID]), state_dict) - if self._snapshot_interval and ((self._num_yielded + 1) % self._snapshot_interval == 0): + self._update_worker_snapshot( + self._worker_key(state_dict[_WORKER_ID]), state_dict + ) + if self._snapshot_interval and ( + (self._num_yielded + 1) % self._snapshot_interval == 0 + ): self._take_snapshot() return data def _take_snapshot(self): main_snapshot_idx = None - while len(self._main_snapshots) and (self._main_snapshots[0][0] <= self._rcvd_idx - 1): + while len(self._main_snapshots) and ( + self._main_snapshots[0][0] <= self._rcvd_idx - 1 + ): main_snapshot_idx, main_snapshot = self._main_snapshots.popleft() if not self._in_order and main_snapshot_idx is None: # in_order is False and no main snapshot is available as we're ahead of rcvd_idx # we can't take a snapshot with the current implementation return - assert main_snapshot_idx == self._rcvd_idx - 1, (main_snapshot_idx, self._rcvd_idx - 1) + assert main_snapshot_idx == self._rcvd_idx - 1, ( + main_snapshot_idx, + self._rcvd_idx - 1, + ) self._update_snapshot( self._num_yielded + 1, self._last_yielded_worker_id, @@ -1561,7 +1685,10 @@ def _update_snapshot( self._SNAPSHOT_STEP: snapshot_step, self._LAST_YIELDED_WORKER_ID: last_yielded_worker_id, self._MAIN_SNAPSHOT: main_snapshot, - self._WORKER_SNAPSHOTS: {key: worker_state.get_state() for key, worker_state in worker_snapshots.items()}, + self._WORKER_SNAPSHOTS: { + key: worker_state.get_state() + for key, worker_state in worker_snapshots.items() + }, } def 
_mark_worker_as_unavailable(self, worker_id, shutdown=False): @@ -1594,7 +1721,11 @@ def _shutdown_workers(self): # Called when shutting down this `_MultiProcessingDataLoaderIter`. # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on # the logic of this function. - if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None: + if ( + _utils is None + or _utils.python_exit_status is True + or _utils.python_exit_status is None + ): # See (2) of the note. If Python is shutting down, do no-op. return # Normal exit when last reference is gone / iterator is depleted. From 36c5b51b9981e9632c4ebbc5684c3ffa79ef851b Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Wed, 5 Feb 2025 12:36:31 -0800 Subject: [PATCH 09/24] run precommit --- torchdata/stateful_dataloader/sampler.py | 7 +- .../stateful_dataloader.py | 212 +++++------------- 2 files changed, 56 insertions(+), 163 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 9172b3e6d..bddef097e 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -107,10 +107,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: assert isinstance(self.sampler_iter, Stateful) self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) - if not ( - isinstance(self.sampler, Stateful) - or isinstance(self.sampler_iter, Stateful) - ) and not isinstance(self.sampler, _InfiniteConstantSampler): + if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( + self.sampler, _InfiniteConstantSampler + ): # We skip x samples if underlying sampler is not stateful for _ in range(self.samples_yielded): next(self.sampler_iter) diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py index f43e23dd7..c986a6dfa 100644 --- a/torchdata/stateful_dataloader/stateful_dataloader.py +++ b/torchdata/stateful_dataloader/stateful_dataloader.py @@ -46,10 +46,7 @@ ) from torch.utils.data.dataloader import _BaseDataLoaderIter, _InfiniteConstantSampler -from torch.utils.data.datapipes.datapipe import ( - _IterDataPipeSerializationWrapper, - _MapDataPipeSerializationWrapper, -) +from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper from .incremental_state import ( _DATASET_ITER_STATE, @@ -218,8 +215,7 @@ def __init__( if num_workers < 0: raise ValueError( - "num_workers option should be non-negative; " - "use num_workers=0 to disable multiprocessing." + "num_workers option should be non-negative; " "use num_workers=0 to disable multiprocessing." ) if timeout < 0: @@ -295,9 +291,7 @@ def __init__( # specific workers. if isinstance(dataset, IterDataPipe): if shuffle is not None: - dataset = torch.utils.data.graph_settings.apply_shuffle_settings( - dataset, shuffle=shuffle - ) + dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default. 
elif shuffle not in {False, None}: raise ValueError( @@ -326,9 +320,7 @@ def __init__( # auto_collation with custom batch_sampler if batch_size != 1 or shuffle or sampler is not None or drop_last: raise ValueError( - "batch_sampler option is mutually exclusive " - "with batch_size, shuffle, sampler, and " - "drop_last" + "batch_sampler option is mutually exclusive " "with batch_size, shuffle, sampler, and " "drop_last" ) batch_size = None drop_last = False @@ -336,8 +328,7 @@ def __init__( # no auto_collation if drop_last: raise ValueError( - "batch_size=None option disables auto-batching " - "and is mutually exclusive with drop_last" + "batch_size=None option disables auto-batching " "and is mutually exclusive with drop_last" ) if sampler is None: # give default samplers @@ -371,9 +362,7 @@ def __init__( # set DataLoader's __initialized attribute. self._DataLoader__initialized = True - self._IterableDataset_len_called = ( - None # See NOTE [ IterableDataset and __len__ ] - ) + self._IterableDataset_len_called = None # See NOTE [ IterableDataset and __len__ ] self._iterator = None @@ -484,9 +473,7 @@ def __init__(self, loader, next_iter_state=None): # Taking care of distributed sharding if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): # For BC, use default SHARDING_PRIORITIES - torch.utils.data.graph_settings.apply_sharding( - self._dataset, self._world_size, self._rank - ) + torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank) if next_iter_state is not None: self.load_state_dict(next_iter_state) @@ -509,9 +496,7 @@ def _next_data(self): def state_dict(self): if self._dataset_kind == _DatasetKind.Iterable: fetcher_state = { - _DATASET_ITER_STATE: try_to_serialize( - self._dataset_fetcher.dataset_iter - ), + _DATASET_ITER_STATE: try_to_serialize(self._dataset_fetcher.dataset_iter), _FETCHER_ENDED: self._dataset_fetcher.ended, } dataset_state = None @@ -541,17 +526,11 @@ def load_state_dict(self, state_dict): self._sampler_iter_yielded = state_dict[_SAMPLER_ITER_YIELDED] # Try to restore from either _index_sampler state_dict or _sampler_iter state_dict - if isinstance(self._index_sampler, Stateful) or isinstance( - self._sampler_iter, Stateful - ): - self._index_sampler = try_to_deserialize( - self._index_sampler, state_dict[_INDEX_SAMPLER_STATE] - ) + if isinstance(self._index_sampler, Stateful) or isinstance(self._sampler_iter, Stateful): + self._index_sampler = try_to_deserialize(self._index_sampler, state_dict[_INDEX_SAMPLER_STATE]) self._sampler_iter = iter(self._index_sampler) if state_dict[_SAMPLER_ITER_STATE] is not None: - self._sampler_iter = try_to_deserialize( - self._sampler_iter, state_dict[_SAMPLER_ITER_STATE] - ) + self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]) if state_dict[_ITERATOR_FINISHED]: try: next(self._sampler_iter) @@ -563,9 +542,7 @@ def load_state_dict(self, state_dict): torch.utils.data.dataloader._InfiniteConstantSampler, ): # Fallback to fastforward - self._sampler_iter = itertools.islice( - self._index_sampler, self._sampler_iter_yielded, None - ) + self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None) self._num_yielded = state_dict[self._NUM_YIELDED] self._IterableDataset_len_called = state_dict[_ITERABLEDATASET_LEN_CALLED] self._shared_seed = state_dict[_SHARED_SEED] @@ -574,12 +551,8 @@ def load_state_dict(self, state_dict): # 1. try to restore dataset state # 2. generate dataset iterator # 3. 
try to restore iterator state - if state_dict[_DATASET_STATE] is not None and isinstance( - self._dataset, Stateful - ): - self._dataset = try_to_deserialize( - self._dataset, state_dict[_DATASET_STATE] - ) + if state_dict[_DATASET_STATE] is not None and isinstance(self._dataset, Stateful): + self._dataset = try_to_deserialize(self._dataset, state_dict[_DATASET_STATE]) self._dataset_fetcher = _DatasetKind.create_fetcher( self._dataset_kind, self._dataset, @@ -589,18 +562,14 @@ def load_state_dict(self, state_dict): ) if self._dataset_kind == _DatasetKind.Iterable: # If either dataset or it's iter is stateful, we don't fast-forward - if isinstance(self._dataset, Stateful) or isinstance( - self._dataset_fetcher.dataset_iter, Stateful - ): + if isinstance(self._dataset, Stateful) or isinstance(self._dataset_fetcher.dataset_iter, Stateful): if state_dict[_FETCHER_STATE] is not None: if state_dict[_FETCHER_STATE][_DATASET_ITER_STATE] is not None: self._dataset_fetcher.dataset_iter = try_to_deserialize( self._dataset_fetcher.dataset_iter, state_dict[_FETCHER_STATE][_DATASET_ITER_STATE], ) - self._dataset_fetcher.ended = state_dict[_FETCHER_STATE][ - _FETCHER_ENDED - ] + self._dataset_fetcher.ended = state_dict[_FETCHER_STATE][_FETCHER_ENDED] else: # No state, just try to fastforward if self._num_yielded > 0: @@ -975,20 +944,16 @@ def __init__(self, loader, next_iter_state): self._SNAPSHOT in next_iter_state ), f"State doesn't contain key '{self._SNAPSHOT}' expected for multiprocess dataloader" wstates = next_iter_state[self._SNAPSHOT].get(self._WORKER_SNAPSHOTS, {}) - assert set(map(self._worker_key, range(len(wstates)))) == set( - wstates.keys() - ), ( + assert set(map(self._worker_key, range(len(wstates)))) == set(wstates.keys()), ( len(wstates), wstates.keys(), ) for worker_key, sd in wstates.items(): worker_states[worker_key] = sd - self._base_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get( - self._BASE_SEED, self._base_seed + self._base_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get(self._BASE_SEED, self._base_seed) + self._shared_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get( + _SHARED_SEED, self._shared_seed ) - self._shared_seed = next_iter_state[self._SNAPSHOT][ - self._MAIN_SNAPSHOT - ].get(_SHARED_SEED, self._shared_seed) for i in range(self._num_workers): # No certainty which module multiprocessing_context is @@ -1036,9 +1001,7 @@ def __init__(self, loader, next_iter_state): if self._pin_memory_device == "xpu": current_device = torch.xpu.current_device() # type: ignore[attr-defined] elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): - custom_device_mod = getattr( - torch, torch._C._get_privateuse1_backend_name() - ) + custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) current_device = custom_device_mod.current_device() else: current_device = torch.cuda.current_device() # choose cuda for default @@ -1071,9 +1034,7 @@ def __init__(self, loader, next_iter_state): import atexit for w in self._workers: - atexit.register( - _StatefulMultiProcessingDataLoaderIter._clean_up_worker, w - ) + atexit.register(_StatefulMultiProcessingDataLoaderIter._clean_up_worker, w) # .pid can be None only before process is spawned (not the case, so ignore) _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] @@ -1089,23 +1050,17 @@ def __init__(self, loader, next_iter_state): # We need to send initial worker state back to the main process to handle 
state_dict() requests # before n >= num_workers steps are taken. # self._worker_snapshots: Dict[str, _IncrementalWorkerState] = {} - self._worker_snapshots = { - key: _IncrementalWorkerState(state) for key, state in worker_states.items() - } + self._worker_snapshots = {key: _IncrementalWorkerState(state) for key, state in worker_states.items()} self._reset(loader, first_iter=True, prime_prefetch=next_iter_state is None) # Try to restore main state if next_iter_state is not None: - self._restore_main_state( - next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT] - ) + self._restore_main_state(next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT]) self._num_yielded = next_iter_state[self._SNAPSHOT][self._SNAPSHOT_STEP] self._update_snapshot( snapshot_step=next_iter_state[self._SNAPSHOT][self._SNAPSHOT_STEP], - last_yielded_worker_id=next_iter_state[self._SNAPSHOT][ - self._LAST_YIELDED_WORKER_ID - ], + last_yielded_worker_id=next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID], num_workers=self._num_workers, main_snapshot=next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT], worker_snapshots=self._worker_snapshots, @@ -1116,10 +1071,7 @@ def __init__(self, loader, next_iter_state): for state in worker_states.values(): if state is None: continue - if ( - state[_DATASET_STATE] is None - and state[_FETCHER_STATE][_DATASET_ITER_STATE] is None - ): + if state[_DATASET_STATE] is None and state[_FETCHER_STATE][_DATASET_ITER_STATE] is None: fast_forward = True break @@ -1136,17 +1088,10 @@ def __init__(self, loader, next_iter_state): for _ in range(self._num_yielded): next(self) # Check if last_yielded_worker_id matches - if ( - self._last_yielded_worker_id - != next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID] - ): - raise ValueError( - "last_yielded_worker_id does not match, the dataset may have changed" - ) + if self._last_yielded_worker_id != next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID]: + raise ValueError("last_yielded_worker_id does not match, the dataset may have changed") else: - self._last_yielded_worker_id = next_iter_state[self._SNAPSHOT][ - self._LAST_YIELDED_WORKER_ID - ] + self._last_yielded_worker_id = next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID] for _ in range(self._last_yielded_worker_id + 1): next(self._worker_queue_idx_cycle) for _ in range(self._prefetch_factor * self._num_workers): @@ -1164,9 +1109,7 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True): # map: task idx => - (worker_id,) if data isn't fetched (outstanding) # \ (worker_id, data) if data is already fetched (out-of-order) self._task_info = {} - self._tasks_outstanding = ( - 0 # always equal to count(v for v in task_info.values() if len(v) == 1) - ) + self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1) # A list of booleans representing whether each worker still has work to # do, i.e., not having exhausted its iterable dataset object. It always # contains all `True`s if not using an iterable-style dataset @@ -1191,9 +1134,7 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True): while remaining > 0: _, data = self._get_data() if not all(self._workers_status): - raise ValueError( - f"A worker has failed during startup! {self._workers_status}" - ) + raise ValueError(f"A worker has failed during startup! 
{self._workers_status}") elif isinstance(data, _AckStartup): if isinstance(data.initial_state, ExceptionWrapper): data.initial_state.reraise() @@ -1201,37 +1142,27 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True): if data.is_delta: self._worker_snapshots[self._worker_key(data.worker_id)].apply_delta(data.initial_state) # type: ignore[arg-type] else: - self._worker_snapshots[ - self._worker_key(data.worker_id) - ] = _IncrementalWorkerState( + self._worker_snapshots[self._worker_key(data.worker_id)] = _IncrementalWorkerState( data.initial_state # type: ignore[arg-type] ) remaining -= 1 else: - raise ValueError( - f"Invalid response from worker after startup: {data}" - ) + raise ValueError(f"Invalid response from worker after startup: {data}") else: # We resume the prefetching in case it was enabled for idx in range(self._num_workers): - self._index_queues[idx].put( - _utils.worker._ResumeIteration(self._shared_seed) - ) + self._index_queues[idx].put(_utils.worker._ResumeIteration(self._shared_seed)) resume_iteration_cnt = self._num_workers while resume_iteration_cnt > 0: return_idx, data = self._get_data() if not all(self._workers_status): - raise ValueError( - f"A worker has failed during Resume! {self._workers_status}" - ) + raise ValueError(f"A worker has failed during Resume! {self._workers_status}") if isinstance(return_idx, _utils.worker._ResumeIteration): assert isinstance(data, _AckStartup), (return_idx, data) if isinstance(data.initial_state, ExceptionWrapper): data.initial_state.reraise() assert data.initial_state is not None, data - self._worker_snapshots[ - self._worker_key(data.worker_id) - ] = _IncrementalWorkerState( + self._worker_snapshots[self._worker_key(data.worker_id)] = _IncrementalWorkerState( data.initial_state # type: ignore[arg-type] ) resume_iteration_cnt -= 1 @@ -1299,9 +1230,7 @@ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): self._mark_worker_as_unavailable(worker_id) if len(failed_workers) > 0: pids_str = ", ".join(str(w.pid) for w in failed_workers) - raise RuntimeError( - f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly" - ) from e + raise RuntimeError(f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly") from e if isinstance(e, queue.Empty): return (False, None) import errno @@ -1439,9 +1368,7 @@ def _get_data(self): if success: return data else: - raise RuntimeError( - f"DataLoader timed out after {self._timeout} seconds" - ) + raise RuntimeError(f"DataLoader timed out after {self._timeout} seconds") elif self._pin_memory: while self._pin_memory_thread.is_alive(): success, data = self._try_get_data() @@ -1474,9 +1401,7 @@ def _next_data(self): info = self._task_info.get(self._rcvd_idx, None) if info: worker_id = info[0] - if ( - len(info) == 2 or self._workers_status[worker_id] - ): # has data or is still active + if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active break del self._task_info[self._rcvd_idx] self._rcvd_idx += 1 @@ -1492,9 +1417,7 @@ def _next_data(self): if len(self._task_info[self._rcvd_idx]) == 2: data, worker_id, state_dict = self._task_info.pop(self._rcvd_idx)[1] if isinstance(data, _utils.worker._IterableDatasetStopIteration): - self._update_worker_snapshot( - self._worker_key(data.worker_id), state_dict - ) + self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict) self._rcvd_idx += 1 continue else: @@ -1511,9 +1434,7 @@ def _next_data(self): self._workers_status[data.worker_id] = False else: 
self._mark_worker_as_unavailable(data.worker_id) - assert ( - state_dict is not None - ), "StopIteration should always be accompanied by a state_dict" + assert state_dict is not None, "StopIteration should always be accompanied by a state_dict" self._try_put_index() # We want to process states until we get to that position # in the worker cycle, therefore if out-of-order we want @@ -1524,9 +1445,7 @@ def _next_data(self): if not self._in_order: # don't store it for later, process now if isinstance(data, _utils.worker._IterableDatasetStopIteration): - self._update_worker_snapshot( - self._worker_key(data.worker_id), state_dict - ) + self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict) continue del self._task_info[idx] return self._process_data(data, worker_id, state_dict) @@ -1534,9 +1453,7 @@ def _next_data(self): else: del self._task_info[idx] if isinstance(data, _utils.worker._IterableDatasetStopIteration): - self._update_worker_snapshot( - self._worker_key(data.worker_id), state_dict - ) + self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict) self._rcvd_idx += 1 continue else: @@ -1558,26 +1475,18 @@ def _restore_main_state(self, state_dict): assert self._num_workers == state_dict[self._NUM_WORKERS] # Try to restore from either _index_sampler state_dict or _sampler_iter state_dict self._sampler_iter_yielded = state_dict[_SAMPLER_ITER_YIELDED] - if isinstance(self._index_sampler, Stateful) or isinstance( - self._sampler_iter, Stateful - ): - self._index_sampler = try_to_deserialize( - self._index_sampler, state_dict[_INDEX_SAMPLER_STATE] - ) + if isinstance(self._index_sampler, Stateful) or isinstance(self._sampler_iter, Stateful): + self._index_sampler = try_to_deserialize(self._index_sampler, state_dict[_INDEX_SAMPLER_STATE]) self._sampler_iter = iter(self._index_sampler) if state_dict[_SAMPLER_ITER_STATE] is not None: - self._sampler_iter = try_to_deserialize( - self._sampler_iter, state_dict[_SAMPLER_ITER_STATE] - ) + self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]) else: if not isinstance( self._index_sampler, torch.utils.data.dataloader._InfiniteConstantSampler, ): # Fallback to fastforward - self._sampler_iter = itertools.islice( - self._index_sampler, self._sampler_iter_yielded, None - ) + self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None) self._IterableDataset_len_called = state_dict[_ITERABLEDATASET_LEN_CALLED] self._shared_seed = state_dict[_SHARED_SEED] self._base_seed = state_dict[self._BASE_SEED] @@ -1614,9 +1523,7 @@ def _try_put_index(self): if self._workers_status[worker_queue_idx]: if self._in_order: break - elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum( - self._workers_status - ): + elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(self._workers_status): # when self._in_order is False, distribute work to a worker if it has capacity # _workers_status is updated only in this thread, so the sum is guaranteed > 0 break @@ -1642,20 +1549,14 @@ def _process_data(self, data, worker_id, state_dict): self._last_yielded_worker_id = worker_id # Update latest worker state if state_dict is not None: - self._update_worker_snapshot( - self._worker_key(state_dict[_WORKER_ID]), state_dict - ) - if self._snapshot_interval and ( - (self._num_yielded + 1) % self._snapshot_interval == 0 - ): + self._update_worker_snapshot(self._worker_key(state_dict[_WORKER_ID]), state_dict) + if self._snapshot_interval and ((self._num_yielded 
+ 1) % self._snapshot_interval == 0): self._take_snapshot() return data def _take_snapshot(self): main_snapshot_idx = None - while len(self._main_snapshots) and ( - self._main_snapshots[0][0] <= self._rcvd_idx - 1 - ): + while len(self._main_snapshots) and (self._main_snapshots[0][0] <= self._rcvd_idx - 1): main_snapshot_idx, main_snapshot = self._main_snapshots.popleft() if not self._in_order and main_snapshot_idx is None: # in_order is False and no main snapshot is available as we're ahead of rcvd_idx @@ -1685,10 +1586,7 @@ def _update_snapshot( self._SNAPSHOT_STEP: snapshot_step, self._LAST_YIELDED_WORKER_ID: last_yielded_worker_id, self._MAIN_SNAPSHOT: main_snapshot, - self._WORKER_SNAPSHOTS: { - key: worker_state.get_state() - for key, worker_state in worker_snapshots.items() - }, + self._WORKER_SNAPSHOTS: {key: worker_state.get_state() for key, worker_state in worker_snapshots.items()}, } def _mark_worker_as_unavailable(self, worker_id, shutdown=False): @@ -1721,11 +1619,7 @@ def _shutdown_workers(self): # Called when shutting down this `_MultiProcessingDataLoaderIter`. # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on # the logic of this function. - if ( - _utils is None - or _utils.python_exit_status is True - or _utils.python_exit_status is None - ): + if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None: # See (2) of the note. If Python is shutting down, do no-op. return # Normal exit when last reference is gone / iterator is depleted. From b9f194d884930a1dc7b35b1cf0208100f42c9f12 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Wed, 5 Feb 2025 12:40:05 -0800 Subject: [PATCH 10/24] remove commented lines --- torchdata/stateful_dataloader/sampler.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index bddef097e..5fc232973 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -113,11 +113,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # We skip x samples if underlying sampler is not stateful for _ in range(self.samples_yielded): next(self.sampler_iter) - # elif self.samples_yielded > 0: - # print("no fast forward, reset") - # # don't re-create sampler_iter unless necessary, we may already have one from init - # self.sampler_iter = iter(self.sampler) - # self.samples_yielded = 0 class BatchSampler(torch.utils.data.sampler.BatchSampler): From eb95deb2a7f8eba6757e5989d0b9892bc233c51a Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Wed, 5 Feb 2025 14:10:37 -0800 Subject: [PATCH 11/24] remove default values --- test/stateful_dataloader/test_state_dict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 9d90f381a..3edd126ad 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -1445,7 +1445,7 @@ def test_fast_state_dict_request_skip_steps(self) -> None: class TestMultiEpochSDL_shard0(TestCase): - def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False): + def get_map_dl(self, data_size, num_workers, batch_size, shuffle): dataset = DummyMapDataset(data_size, shuffle=False) return StatefulDataLoader( dataset=dataset, @@ -1455,7 +1455,7 @@ def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False): multiprocessing_context=("forkserver" if IS_MACOS 
and num_workers else None), ) - def _run(self, data_size, num_workers, batch_size, shuffle=False): + def _run(self, data_size, num_workers, batch_size, shuffle): dl1 = self.get_map_dl( data_size=data_size, num_workers=num_workers, @@ -1518,7 +1518,7 @@ def test_multiprocess_shuffle(self): class TestEndOfEpochBehavior_shard0(TestCase): - def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False): + def get_map_dl(self, data_size, num_workers, batch_size, shuffle): dataset = DummyMapDataset(data_size, shuffle=False) return StatefulDataLoader( dataset=dataset, @@ -1534,7 +1534,7 @@ def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int: num_items_yielded += 1 return num_items_yielded - def _run(self, data_size, num_workers, batch_size, shuffle=False): + def _run(self, data_size, num_workers, batch_size, shuffle): dl = self.get_map_dl( data_size=data_size, num_workers=num_workers, From d783247008f35861c274ee54e84076b00ecf434d Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Wed, 5 Feb 2025 17:18:38 -0800 Subject: [PATCH 12/24] only exhaust sampler_iter if present in sd --- torchdata/stateful_dataloader/stateful_dataloader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py index c986a6dfa..012646e1b 100644 --- a/torchdata/stateful_dataloader/stateful_dataloader.py +++ b/torchdata/stateful_dataloader/stateful_dataloader.py @@ -531,11 +531,11 @@ def load_state_dict(self, state_dict): self._sampler_iter = iter(self._index_sampler) if state_dict[_SAMPLER_ITER_STATE] is not None: self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]) - if state_dict[_ITERATOR_FINISHED]: - try: - next(self._sampler_iter) - except StopIteration: - pass + if state_dict[_ITERATOR_FINISHED]: + try: + next(self._sampler_iter) + except StopIteration: + pass else: if not isinstance( self._index_sampler, From 6d49b4f393de401cb272eb6ec55f306d981d1f3d Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Thu, 6 Feb 2025 12:00:37 -0800 Subject: [PATCH 13/24] update _StatefulRandomSamplerIterator update state dict if the iterator has finished add comment about why were updating state dict run precommit --- torchdata/stateful_dataloader/sampler.py | 8 +++++++- torchdata/stateful_dataloader/stateful_dataloader.py | 11 ++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 5fc232973..64d2ed06a 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -32,7 +32,6 @@ def __next__(self) -> int: self.yielded = self.next_yielded self.next_yielded = None - val = next(self.parent_iterator) self.yielded += 1 return val @@ -42,6 +41,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.sampler.generator.set_state(state_dict[self._GENERATOR]) self.next_yielded = state_dict[self._YIELDED] + def update_state_dict(self) -> None: + self.generator_state = self.sampler.generator.get_state() + def state_dict(self) -> Dict[str, Any]: return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded} @@ -114,6 +116,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: for _ in range(self.samples_yielded): next(self.sampler_iter) + def update_state_dict(self) -> None: + if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, 
"update_state_dict"): + self.sampler_iter.update_state_dict() + class BatchSampler(torch.utils.data.sampler.BatchSampler): def __init__(self, sampler, batch_size, drop_last): diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py index 012646e1b..680a3e23c 100644 --- a/torchdata/stateful_dataloader/stateful_dataloader.py +++ b/torchdata/stateful_dataloader/stateful_dataloader.py @@ -449,6 +449,11 @@ def __next__(self): try: return super().__next__() except StopIteration: + # If we are at the end of the iteration, we want to update the state dict of _sampler_iter. + # because in __iter__ after self._iterator is set using self._get_iterator() [which makes self.next_iter_state = None], + # it is checked if self._iterator._finished is True, and if it is, self._iterator is reset with next_iter_state = None. + if hasattr(self._sampler_iter, "update_state_dict"): + self._sampler_iter.update_state_dict() self._finished = True raise @@ -531,11 +536,7 @@ def load_state_dict(self, state_dict): self._sampler_iter = iter(self._index_sampler) if state_dict[_SAMPLER_ITER_STATE] is not None: self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]) - if state_dict[_ITERATOR_FINISHED]: - try: - next(self._sampler_iter) - except StopIteration: - pass + else: if not isinstance( self._index_sampler, From 20a14e5f2fbe4b65d490ae69d2d1c1faf355d6b1 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Fri, 7 Feb 2025 08:55:44 -0800 Subject: [PATCH 14/24] update randomsampleriter state_dict fully --- torchdata/stateful_dataloader/sampler.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 64d2ed06a..203a4e042 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -43,6 +43,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def update_state_dict(self) -> None: self.generator_state = self.sampler.generator.get_state() + self.yielded = 0 def state_dict(self) -> Dict[str, Any]: return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded} @@ -109,15 +110,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: assert isinstance(self.sampler_iter, Stateful) self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) - if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( - self.sampler, _InfiniteConstantSampler - ): + if not ( + isinstance(self.sampler, Stateful) + or isinstance(self.sampler_iter, Stateful) + ) and not isinstance(self.sampler, _InfiniteConstantSampler): # We skip x samples if underlying sampler is not stateful for _ in range(self.samples_yielded): next(self.sampler_iter) def update_state_dict(self) -> None: - if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"): + if isinstance(self.sampler_iter, Stateful) and hasattr( + self.sampler_iter, "update_state_dict" + ): self.sampler_iter.update_state_dict() From 093e5f251104a4ab93ec15f0f1b004c4567d91c4 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Fri, 7 Feb 2025 08:57:58 -0800 Subject: [PATCH 15/24] run precommit --- torchdata/stateful_dataloader/sampler.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 
203a4e042..f8fb2339f 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -110,18 +110,15 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: assert isinstance(self.sampler_iter, Stateful) self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) - if not ( - isinstance(self.sampler, Stateful) - or isinstance(self.sampler_iter, Stateful) - ) and not isinstance(self.sampler, _InfiniteConstantSampler): + if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( + self.sampler, _InfiniteConstantSampler + ): # We skip x samples if underlying sampler is not stateful for _ in range(self.samples_yielded): next(self.sampler_iter) def update_state_dict(self) -> None: - if isinstance(self.sampler_iter, Stateful) and hasattr( - self.sampler_iter, "update_state_dict" - ): + if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"): self.sampler_iter.update_state_dict() From 39995a32b032e7cee49b7c49bc5ec887131efd57 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Mon, 10 Feb 2025 20:45:45 -0800 Subject: [PATCH 16/24] fork torch.utils.data RandomSampler reverse changes to sdl.py generator to iterator run precommit update generator usage --- torchdata/stateful_dataloader/sampler.py | 142 ++++++++++++++---- .../stateful_dataloader.py | 6 - 2 files changed, 113 insertions(+), 35 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index f8fb2339f..c9b644d1b 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -10,61 +10,145 @@ import torch.utils.data.sampler from torch.utils.data import Dataset from torch.utils.data.dataloader import _InfiniteConstantSampler +from torch.utils.data.sampler import Sampler from .stateful import Stateful -class _StatefulRandomSamplerIterator(Iterator[int], Stateful): +class StatefulRandomSamplerIterator(Iterator[int], Stateful): _GENERATOR = "generator" _YIELDED = "yielded" - def __init__(self, sampler, parent_iterator: Iterator[int]): + def __init__(self, sampler): self.sampler = sampler - self.parent_iterator = parent_iterator + self.generator_state = self.sampler.generator.get_state() self.yielded = 0 self.next_yielded = None - self.generator_state = sampler.generator.get_state() + self.n = len(sampler.data_source) + self.replacement = sampler.replacement + self.num_samples = sampler.num_samples + self.chunk_size = 32 + self.chunk_index = 0 + self.perm_index = 0 + self.perm = None - def __next__(self) -> int: - if self.next_yielded is not None: - for _ in range(self.next_yielded): - next(self.parent_iterator) + def __iter__(self): + return self + + def __next__(self): + if self.replacement: + num_full_chunks = self.num_samples // self.chunk_size + remainder = self.num_samples % self.chunk_size + if self.chunk_index < num_full_chunks: + if self.perm is None or not self.perm: + self.perm = torch.randint( + high=self.n, + size=(self.chunk_size,), + dtype=torch.int64, + generator=self.sampler.generator, + ).tolist() + self.perm_index = 0 + value = self.perm[self.perm_index] + self.perm_index += 1 + if self.perm_index == self.chunk_size: + self.chunk_index += 1 + self.perm = None + self.yielded += 1 + return value + elif remainder > 0: + if self.perm is None or not self.perm: + self.perm = torch.randint( + high=self.n, + size=(remainder,), + dtype=torch.int64, + generator=self.sampler.generator, + ).tolist() 
+ self.perm_index = 0 + value = self.perm[self.perm_index] + self.perm_index += 1 + if self.perm_index == remainder: + raise StopIteration + self.yielded += 1 + return value + else: + raise StopIteration + else: + num_full_perms = self.num_samples // self.n + remainder = self.num_samples % self.n + if self.chunk_index < num_full_perms: + if self.perm is None or not self.perm: + self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist() + self.perm_index = 0 + value = self.perm[self.perm_index] + self.perm_index += 1 + if self.perm_index == self.n: + self.chunk_index += 1 + self.perm = None + self.yielded += 1 + return value + elif remainder > 0: + if self.perm is None or not self.perm: + self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()[:remainder] + self.perm_index = 0 + value = self.perm[self.perm_index] + self.perm_index += 1 + if self.perm_index == remainder: + raise StopIteration + self.yielded += 1 + return value + else: + raise StopIteration - self.yielded = self.next_yielded - self.next_yielded = None - val = next(self.parent_iterator) - self.yielded += 1 - return val + def state_dict(self) -> dict: + return { + self._YIELDED: self.yielded, + self._GENERATOR: self.generator_state, + } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - self.generator_state = state_dict[self._GENERATOR] - self.sampler.generator.set_state(state_dict[self._GENERATOR]) + def load_state_dict(self, state_dict: dict) -> None: self.next_yielded = state_dict[self._YIELDED] - - def update_state_dict(self) -> None: - self.generator_state = self.sampler.generator.get_state() - self.yielded = 0 - - def state_dict(self) -> Dict[str, Any]: - return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded} + self.generator_state = state_dict[self._GENERATOR] + self.sampler.generator.set_state(self.generator_state) + if self.next_yielded is not None: + for _ in range(self.next_yielded - self.yielded): + next(self) + self.yielded = self.next_yielded + self.next_yielded = None -class RandomSampler(torch.utils.data.sampler.RandomSampler): +class RandomSampler(Sampler[int]): def __init__( self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None, - ): + ) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples if generator is None: # Ensure that underlying sampler has something repeatable generator = torch.Generator() generator.manual_seed(1) - super().__init__(data_source, replacement, num_samples, generator) - - def __iter__(self): - return _StatefulRandomSamplerIterator(self, super().__iter__()) + self.generator = generator + if not isinstance(self.replacement, bool): + raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + return StatefulRandomSamplerIterator(self) + + def __len__(self) -> int: + return self.num_samples class _BatchSamplerIterator(Iterator[list[int]], Stateful): diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py index 
680a3e23c..1ffeec298 100644 --- a/torchdata/stateful_dataloader/stateful_dataloader.py +++ b/torchdata/stateful_dataloader/stateful_dataloader.py @@ -449,11 +449,6 @@ def __next__(self): try: return super().__next__() except StopIteration: - # If we are at the end of the iteration, we want to update the state dict of _sampler_iter. - # because in __iter__ after self._iterator is set using self._get_iterator() [which makes self.next_iter_state = None], - # it is checked if self._iterator._finished is True, and if it is, self._iterator is reset with next_iter_state = None. - if hasattr(self._sampler_iter, "update_state_dict"): - self._sampler_iter.update_state_dict() self._finished = True raise @@ -536,7 +531,6 @@ def load_state_dict(self, state_dict): self._sampler_iter = iter(self._index_sampler) if state_dict[_SAMPLER_ITER_STATE] is not None: self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]) - else: if not isinstance( self._index_sampler, From 1ac45db601c1b17a83e5ca6d21373bc872aab86e Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Mon, 10 Feb 2025 22:15:04 -0800 Subject: [PATCH 17/24] update class name --- torchdata/stateful_dataloader/sampler.py | 31 ++++++++++++++++-------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index c9b644d1b..dd2ddb6aa 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -15,7 +15,7 @@ from .stateful import Stateful -class StatefulRandomSamplerIterator(Iterator[int], Stateful): +class _StatefulRandomSamplerIterator(Iterator[int], Stateful): _GENERATOR = "generator" _YIELDED = "yielded" @@ -77,7 +77,9 @@ def __next__(self): remainder = self.num_samples % self.n if self.chunk_index < num_full_perms: if self.perm is None or not self.perm: - self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist() + self.perm = torch.randperm( + self.n, generator=self.sampler.generator + ).tolist() self.perm_index = 0 value = self.perm[self.perm_index] self.perm_index += 1 @@ -88,7 +90,9 @@ def __next__(self): return value elif remainder > 0: if self.perm is None or not self.perm: - self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()[:remainder] + self.perm = torch.randperm( + self.n, generator=self.sampler.generator + ).tolist()[:remainder] self.perm_index = 0 value = self.perm[self.perm_index] self.perm_index += 1 @@ -133,9 +137,13 @@ def __init__( generator.manual_seed(1) self.generator = generator if not isinstance(self.replacement, bool): - raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") + raise TypeError( + f"replacement should be a boolean value, but got replacement={self.replacement}" + ) if not isinstance(self.num_samples, int) or self.num_samples <= 0: - raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") + raise ValueError( + f"num_samples should be a positive integer value, but got num_samples={self.num_samples}" + ) @property def num_samples(self) -> int: @@ -145,7 +153,7 @@ def num_samples(self) -> int: return self._num_samples def __iter__(self) -> Iterator[int]: - return StatefulRandomSamplerIterator(self) + return _StatefulRandomSamplerIterator(self) def __len__(self) -> int: return self.num_samples @@ -194,15 +202,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: assert isinstance(self.sampler_iter, 
Stateful) self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) - if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( - self.sampler, _InfiniteConstantSampler - ): + if not ( + isinstance(self.sampler, Stateful) + or isinstance(self.sampler_iter, Stateful) + ) and not isinstance(self.sampler, _InfiniteConstantSampler): # We skip x samples if underlying sampler is not stateful for _ in range(self.samples_yielded): next(self.sampler_iter) def update_state_dict(self) -> None: - if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"): + if isinstance(self.sampler_iter, Stateful) and hasattr( + self.sampler_iter, "update_state_dict" + ): self.sampler_iter.update_state_dict() From 5167a94d02b226e196a2c149667298d44b654108 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Mon, 10 Feb 2025 22:30:39 -0800 Subject: [PATCH 18/24] run precommit --- torchdata/stateful_dataloader/sampler.py | 27 +++++++----------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index dd2ddb6aa..0c4164976 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -77,9 +77,7 @@ def __next__(self): remainder = self.num_samples % self.n if self.chunk_index < num_full_perms: if self.perm is None or not self.perm: - self.perm = torch.randperm( - self.n, generator=self.sampler.generator - ).tolist() + self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist() self.perm_index = 0 value = self.perm[self.perm_index] self.perm_index += 1 @@ -90,9 +88,7 @@ def __next__(self): return value elif remainder > 0: if self.perm is None or not self.perm: - self.perm = torch.randperm( - self.n, generator=self.sampler.generator - ).tolist()[:remainder] + self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()[:remainder] self.perm_index = 0 value = self.perm[self.perm_index] self.perm_index += 1 @@ -137,13 +133,9 @@ def __init__( generator.manual_seed(1) self.generator = generator if not isinstance(self.replacement, bool): - raise TypeError( - f"replacement should be a boolean value, but got replacement={self.replacement}" - ) + raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") if not isinstance(self.num_samples, int) or self.num_samples <= 0: - raise ValueError( - f"num_samples should be a positive integer value, but got num_samples={self.num_samples}" - ) + raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") @property def num_samples(self) -> int: @@ -202,18 +194,15 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: assert isinstance(self.sampler_iter, Stateful) self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) - if not ( - isinstance(self.sampler, Stateful) - or isinstance(self.sampler_iter, Stateful) - ) and not isinstance(self.sampler, _InfiniteConstantSampler): + if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( + self.sampler, _InfiniteConstantSampler + ): # We skip x samples if underlying sampler is not stateful for _ in range(self.samples_yielded): next(self.sampler_iter) def update_state_dict(self) -> None: - if isinstance(self.sampler_iter, Stateful) and hasattr( - self.sampler_iter, "update_state_dict" - ): + if 
isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"): self.sampler_iter.update_state_dict() From 34dc402fc992bf0db74b73206e3c7a2dd32cf728 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Tue, 11 Feb 2025 13:46:37 -0800 Subject: [PATCH 19/24] add a method to generate permutations --- torchdata/stateful_dataloader/sampler.py | 37 ++++++++++++------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 0c4164976..45aaefb11 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -35,18 +35,24 @@ def __init__(self, sampler): def __iter__(self): return self + def _get_perm(self, replacement: bool, num_samples: int) -> List[int]: + if replacement: + return torch.randint( + high=self.n, + size=(num_samples,), + dtype=torch.int64, + generator=self.sampler.generator, + ).tolist() + else: + return torch.randperm(self.n, generator=self.sampler.generator).tolist()[:num_samples] + def __next__(self): if self.replacement: num_full_chunks = self.num_samples // self.chunk_size remainder = self.num_samples % self.chunk_size if self.chunk_index < num_full_chunks: - if self.perm is None or not self.perm: - self.perm = torch.randint( - high=self.n, - size=(self.chunk_size,), - dtype=torch.int64, - generator=self.sampler.generator, - ).tolist() + if self.perm is None: + self.perm = self._get_perm(self.replacement, self.chunk_size) self.perm_index = 0 value = self.perm[self.perm_index] self.perm_index += 1 @@ -56,13 +62,8 @@ def __next__(self): self.yielded += 1 return value elif remainder > 0: - if self.perm is None or not self.perm: - self.perm = torch.randint( - high=self.n, - size=(remainder,), - dtype=torch.int64, - generator=self.sampler.generator, - ).tolist() + if self.perm is None: + self.perm = self._get_perm(self.replacement, remainder) self.perm_index = 0 value = self.perm[self.perm_index] self.perm_index += 1 @@ -76,8 +77,8 @@ def __next__(self): num_full_perms = self.num_samples // self.n remainder = self.num_samples % self.n if self.chunk_index < num_full_perms: - if self.perm is None or not self.perm: - self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist() + if self.perm is None: + self.perm = self._get_perm(self.replacement, self.n) self.perm_index = 0 value = self.perm[self.perm_index] self.perm_index += 1 @@ -87,8 +88,8 @@ def __next__(self): self.yielded += 1 return value elif remainder > 0: - if self.perm is None or not self.perm: - self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()[:remainder] + if self.perm is None: + self.perm = self._get_perm(self.replacement, remainder) self.perm_index = 0 value = self.perm[self.perm_index] self.perm_index += 1 From 2d3da6cd8f0a413ca2d8a59e551745911f2a57fa Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Tue, 11 Feb 2025 13:50:58 -0800 Subject: [PATCH 20/24] update return type --- torchdata/stateful_dataloader/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 45aaefb11..6de94edbb 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
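[Illustrative aside, not part of the patch series.] The _get_perm helper introduced in PATCH 19/24 consolidates the two index-drawing strategies behind one call; below is a minimal standalone sketch of the same idea, with names of my own choosing and a seeded torch.Generator so the draws are reproducible:

    import torch
    from typing import List

    def draw_indices(n: int, count: int, replacement: bool, generator: torch.Generator) -> List[int]:
        # With replacement: `count` indices drawn uniformly from range(n), repeats allowed.
        if replacement:
            return torch.randint(high=n, size=(count,), dtype=torch.int64, generator=generator).tolist()
        # Without replacement: the first `count` entries of a fresh permutation of range(n).
        return torch.randperm(n, generator=generator).tolist()[:count]

    gen = torch.Generator()
    gen.manual_seed(1)
    print(draw_indices(10, 4, replacement=True, generator=gen))
    print(draw_indices(10, 10, replacement=False, generator=gen))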
import itertools -from typing import Any, Dict, Iterator, Optional, Sized +from typing import Any, Dict, Iterator, List, Optional, Sized import torch.utils.data.sampler from torch.utils.data import Dataset From 53f37a0ef6c4b65b2869d1e58eb04aaebb844536 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Tue, 11 Feb 2025 16:10:52 -0800 Subject: [PATCH 21/24] update next logic --- torchdata/stateful_dataloader/sampler.py | 75 ++++++------------------ 1 file changed, 17 insertions(+), 58 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 6de94edbb..cd13c1f2d 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -28,77 +28,34 @@ def __init__(self, sampler): self.replacement = sampler.replacement self.num_samples = sampler.num_samples self.chunk_size = 32 - self.chunk_index = 0 + self.perm: List[int] = self._get_perm() self.perm_index = 0 - self.perm = None + self.chunk_index = 0 def __iter__(self): return self - def _get_perm(self, replacement: bool, num_samples: int) -> List[int]: - if replacement: + def _get_perm(self) -> List[int]: + if self.replacement: return torch.randint( high=self.n, - size=(num_samples,), + size=(self.chunk_size,), dtype=torch.int64, generator=self.sampler.generator, ).tolist() else: - return torch.randperm(self.n, generator=self.sampler.generator).tolist()[:num_samples] + return torch.randperm(self.n, generator=self.sampler.generator).tolist() def __next__(self): - if self.replacement: - num_full_chunks = self.num_samples // self.chunk_size - remainder = self.num_samples % self.chunk_size - if self.chunk_index < num_full_chunks: - if self.perm is None: - self.perm = self._get_perm(self.replacement, self.chunk_size) - self.perm_index = 0 - value = self.perm[self.perm_index] - self.perm_index += 1 - if self.perm_index == self.chunk_size: - self.chunk_index += 1 - self.perm = None - self.yielded += 1 - return value - elif remainder > 0: - if self.perm is None: - self.perm = self._get_perm(self.replacement, remainder) - self.perm_index = 0 - value = self.perm[self.perm_index] - self.perm_index += 1 - if self.perm_index == remainder: - raise StopIteration - self.yielded += 1 - return value - else: - raise StopIteration - else: - num_full_perms = self.num_samples // self.n - remainder = self.num_samples % self.n - if self.chunk_index < num_full_perms: - if self.perm is None: - self.perm = self._get_perm(self.replacement, self.n) - self.perm_index = 0 - value = self.perm[self.perm_index] - self.perm_index += 1 - if self.perm_index == self.n: - self.chunk_index += 1 - self.perm = None - self.yielded += 1 - return value - elif remainder > 0: - if self.perm is None: - self.perm = self._get_perm(self.replacement, remainder) - self.perm_index = 0 - value = self.perm[self.perm_index] - self.perm_index += 1 - if self.perm_index == remainder: - raise StopIteration - self.yielded += 1 - return value - else: - raise StopIteration + if self.yielded == self.num_samples: + raise StopIteration() + if self.perm_index == len(self.perm): + self.perm = self._get_perm() + self.perm_index = 0 + val = self.perm[self.perm_index] + self.perm_index += 1 + self.yielded += 1 + return val def state_dict(self) -> dict: return { @@ -110,7 +67,9 @@ def load_state_dict(self, state_dict: dict) -> None: self.next_yielded = state_dict[self._YIELDED] self.generator_state = state_dict[self._GENERATOR] self.sampler.generator.set_state(self.generator_state) + if self.next_yielded is not None: + 
self.perm = self._get_perm() for _ in range(self.next_yielded - self.yielded): next(self) self.yielded = self.next_yielded From da55b343d5aecbda5e12d5abaf248b711879156d Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Tue, 11 Feb 2025 16:19:31 -0800 Subject: [PATCH 22/24] add comment --- torchdata/stateful_dataloader/sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index cd13c1f2d..d0ae1a3eb 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -69,8 +69,8 @@ def load_state_dict(self, state_dict: dict) -> None: self.sampler.generator.set_state(self.generator_state) if self.next_yielded is not None: - self.perm = self._get_perm() - for _ in range(self.next_yielded - self.yielded): + self.perm = self._get_perm() # We want permutations from the latest generator state that's loaded + for _ in range(self.next_yielded): next(self) self.yielded = self.next_yielded self.next_yielded = None From 5fa27b82aa1f2f9434a14c18ea203d65a4bae377 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Thu, 13 Feb 2025 15:23:50 -0800 Subject: [PATCH 23/24] update tests to include non stateful samplers --- test/stateful_dataloader/test_state_dict.py | 138 ++++++++++++++------ 1 file changed, 95 insertions(+), 43 deletions(-) diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 3edd126ad..d53c26fad 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -15,6 +15,8 @@ import torch import torch.utils.data + +from parameterized import parameterized from torch.testing._internal.common_utils import IS_MACOS, TEST_CUDA, TestCase from torchdata.stateful_dataloader import Stateful, StatefulDataLoader @@ -1456,7 +1458,7 @@ def get_map_dl(self, data_size, num_workers, batch_size, shuffle): ) def _run(self, data_size, num_workers, batch_size, shuffle): - dl1 = self.get_map_dl( + dataloader1 = self.get_map_dl( data_size=data_size, num_workers=num_workers, batch_size=batch_size, @@ -1464,57 +1466,49 @@ def _run(self, data_size, num_workers, batch_size, shuffle): ) # Run through the dataloader for 2 epochs and count the number of items yielded num_items_yielded = 0 - dl1_items = [] + dataloader1_items = [] for _ in range(2): - for batch in dl1: - dl1_items.append(batch) + for batch in dataloader1: + dataloader1_items.append(batch) num_items_yielded += 1 # Save the state dict - state_dict = dl1.state_dict() + state_dict = dataloader1.state_dict() # Create a new StatefulDataLoader instance and load the state dict - new_dl1 = self.get_map_dl( + new_dataloader1 = self.get_map_dl( data_size=data_size, num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, ) - new_dl1.load_state_dict(state_dict) + new_dataloader1.load_state_dict(state_dict) # Run through the new dataloader for another 2 epochs and count the number of items yielded additional_num_items_yielded = 0 for i in range(2): epoch_num_items_yielded = 0 - for batch in new_dl1: - dl1_items.append(batch) + for batch in new_dataloader1: + dataloader1_items.append(batch) epoch_num_items_yielded += 1 additional_num_items_yielded += epoch_num_items_yielded # Check that the total number of items yielded is correct self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4) # now run a second dataloder for 4 epochs and check if the order is same. 
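[Illustrative aside, not part of the patch series.] The resume path reworked in PATCH 21/24 and 22/24 restores the saved generator state, regenerates the permutation from that state, and then fast-forwards through the samples already consumed before the checkpoint. A minimal standalone sketch of that pattern follows, using a hypothetical toy iterator rather than the library class:

    import torch

    class TinyShuffleIter:
        # Yields a permutation of range(n) and can checkpoint/resume mid-stream.
        def __init__(self, n: int, generator: torch.Generator):
            self.n = n
            self.generator = generator
            self.generator_state = generator.get_state()  # state *before* drawing the permutation
            self.perm = torch.randperm(n, generator=generator).tolist()
            self.yielded = 0

        def __iter__(self):
            return self

        def __next__(self):
            if self.yielded == self.n:
                raise StopIteration
            value = self.perm[self.yielded]
            self.yielded += 1
            return value

        def state_dict(self):
            return {"generator": self.generator_state, "yielded": self.yielded}

        def load_state_dict(self, sd):
            # Restore the generator, rebuild the permutation from that state,
            # then skip the samples that were already consumed before the checkpoint.
            self.generator.set_state(sd["generator"])
            self.perm = torch.randperm(self.n, generator=self.generator).tolist()
            self.yielded = 0
            for _ in range(sd["yielded"]):
                next(self)

    gen = torch.Generator()
    gen.manual_seed(1)
    it = TinyShuffleIter(8, gen)
    head = [next(it) for _ in range(3)]
    sd = it.state_dict()

    gen2 = torch.Generator()
    gen2.manual_seed(1)
    resumed = TinyShuffleIter(8, gen2)
    resumed.load_state_dict(sd)
    assert head + list(resumed) == it.perm  # the resumed stream continues the same permutation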
- dl2 = self.get_map_dl( + dataloader2 = self.get_map_dl( data_size=data_size, num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, ) - dl2_items = [] + dataloader2_items = [] for _ in range(4): - for batch in dl2: - dl2_items.append(batch) - - self.assertEqual(dl1_items, dl2_items) - - def test_main_process(self): - self._run(100, 0, 1, False) - - def test_multiprocess(self): - self._run(100, 2, 1, False) + for batch in dataloader2: + dataloader2_items.append(batch) - def test_main_process_shuffle(self): - self._run(100, 0, 1, True) + self.assertEqual(dataloader1_items, dataloader2_items) - def test_multiprocess_shuffle(self): - self._run(100, 2, 1, True) + @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True])) + def test_multi_epoch_sdl(self, datasize, num_workers, batch_size, shuffle): + self._run(datasize, num_workers, batch_size, shuffle) class TestEndOfEpochBehavior_shard0(TestCase): @@ -1535,7 +1529,7 @@ def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int: return num_items_yielded def _run(self, data_size, num_workers, batch_size, shuffle): - dl = self.get_map_dl( + dataloader = self.get_map_dl( data_size=data_size, num_workers=num_workers, batch_size=batch_size, @@ -1544,52 +1538,110 @@ def _run(self, data_size, num_workers, batch_size, shuffle): # Run through the dataloader for 1 epoch and count the number of items yielded num_items_yielded = 0 - for batch in dl: + for batch in dataloader: num_items_yielded += 1 - sd_in = dl.state_dict() - sd_out = dl.state_dict() + sd_in = dataloader.state_dict() + sd_out = dataloader.state_dict() self.assertEqual(num_items_yielded, data_size) # Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch - dl_sd_in = self.get_map_dl( + dataloader_sd_in = self.get_map_dl( data_size=data_size, num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, ) - dl_sd_in.load_state_dict(sd_in) + dataloader_sd_in.load_state_dict(sd_in) # Run through the new dataloader for 1 epoch and count the number of items yielded # num_items_yielded should be 0 since the state dict was saved before the end of epoch - num_items_yielded = self._count_items_yielded(dl_sd_in) + num_items_yielded = self._count_items_yielded(dataloader_sd_in) self.assertEqual(num_items_yielded, 0) # Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch - dl_sd_out = self.get_map_dl( + dataloader_sd_out = self.get_map_dl( data_size=data_size, num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, ) - dl_sd_out.load_state_dict(sd_out) + dataloader_sd_out.load_state_dict(sd_out) # Run through the new dataloader for 1 epoch and count the number of items yielded # num_items_yielded should be data_size since the state dict was saved after the end of epoch - num_items_yielded = self._count_items_yielded(dl_sd_out) + num_items_yielded = self._count_items_yielded(dataloader_sd_out) self.assertEqual(num_items_yielded, data_size) - def test_main_process(self): - self._run(100, 0, 1, False) + @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True])) + def test_end_of_epoch_behavior(self, datasize, num_workers, batch_size, shuffle): + self._run(datasize, num_workers, batch_size, shuffle) + + +class TestNotStatefulSamplerSDL_shard0(TestCase): + def get_map_dl(self, data_size, num_workers, batch_size, sampler_cls): + dataset = DummyMapDataset(data_size, shuffle=False) + sampler = sampler_cls(dataset) + return StatefulDataLoader( + dataset=dataset, 
+ num_workers=num_workers, + batch_size=batch_size, + sampler=sampler, + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + ) + + def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls): + torch.manual_seed(0) + dataloader1 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + sampler_cls=sampler_cls, + ) + # interrupt the dataloader after interrupt batches and save the state dict + results_dataloader1 = [] + for i, batch in enumerate(dataloader1): + results_dataloader1.append(batch) + if i == interrupt: + break + state_dict = dataloader1.state_dict() - def test_multiprocess(self): - self._run(100, 2, 1, False) + torch.manual_seed(0) + resumed_dataloader1 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + sampler_cls=sampler_cls, + ) + resumed_dataloader1.load_state_dict(state_dict) - def test_main_process_shuffle(self): - self._run(100, 0, 1, True) + for batch in resumed_dataloader1: + results_dataloader1.append(batch) - def test_multiprocess_shuffle(self): - self._run(100, 2, 1, True) + # now start a completely new dataloader and get all the batches + torch.manual_seed(0) + dataloader2 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + sampler_cls=sampler_cls, + ) + results_dataloader2 = [] + for batch in dataloader2: + results_dataloader2.append(batch) + self.assertEqual(results_dataloader1, results_dataloader2) + + @parameterized.expand( + itertools.product( + [100], + [0, 2], + [1], + [10, 50, 80], + [torch.utils.data.RandomSampler, torch.utils.data.SequentialSampler], + ) + ) + def test_notstatefulSDL(self, data_size, num_workers, batch_size, interrupt, sampler_cls): + self._run(100, 0, 1, interrupt, sampler_cls) class TestMultiEpochState_shard0(TestCase): From 0bdd8c26c68d9718dc0fed7ddcf033933087b1ab Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Thu, 13 Feb 2025 15:26:20 -0800 Subject: [PATCH 24/24] add comments --- test/stateful_dataloader/test_state_dict.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index d53c26fad..01bb17ed5 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -1591,7 +1591,7 @@ def get_map_dl(self, data_size, num_workers, batch_size, sampler_cls): ) def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls): - torch.manual_seed(0) + torch.manual_seed(0) # Fixing seed for deterministic results dataloader1 = self.get_map_dl( data_size=data_size, num_workers=num_workers, @@ -1606,7 +1606,9 @@ def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls): break state_dict = dataloader1.state_dict() - torch.manual_seed(0) + torch.manual_seed( + 0 + ) # We need to fix seed again so that before fast forwarding we are at the same state of gen as before resumed_dataloader1 = self.get_map_dl( data_size=data_size, num_workers=num_workers,
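[Illustrative usage sketch, not part of the patch series; it assumes the torchdata.stateful_dataloader package built by these patches.] The new TestNotStatefulSamplerSDL cases check that a run interrupted after a few batches and resumed from state_dict() reproduces the same overall sequence as an uninterrupted run, even when the sampler itself keeps no state. The same check in miniature:

    import torch
    from torch.utils.data import SequentialSampler
    from torchdata.stateful_dataloader import StatefulDataLoader

    data = list(range(20))  # a plain list works as a map-style dataset

    def make_loader():
        torch.manual_seed(0)  # fixed seed, as in the tests above
        return StatefulDataLoader(data, batch_size=1, sampler=SequentialSampler(data))

    loader = make_loader()
    seen = []
    for i, batch in enumerate(loader):
        seen.append(batch)
        if i == 4:  # interrupt after five batches
            break
    state = loader.state_dict()

    resumed = make_loader()
    resumed.load_state_dict(state)
    seen.extend(batch for batch in resumed)

    baseline = [batch for batch in make_loader()]
    assert [b.tolist() for b in seen] == [b.tolist() for b in baseline]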