Fix end of epoch StatefulDataLoader restart #1439
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1439

Note: Links to docs will display an error until the docs builds have been completed.

✅ No failures as of commit da55b34 with merge base fe6b405.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
This does not solve the problem, as it just restarts the dataloader and produces the same batches again.
If it's any help: while creating the issue I noticed that after loading the state dict, the resulting state dict in the dataloader differs from the one that was loaded. For example, "samples_yielded" is set to 0 when the loaded one had 100 (see the prints in #1437), and possibly more fields differ that I haven't checked. Looking at the code in this commit, it seems that samples_yielded is being set manually; maybe that is the root of the problem?
Force-pushed from 8136e63 to a074b50 (compare). Commits: update stateful_dataloader; run precommit; local changes; update test to test the order of batches; update test; update tests; revert changes in SDL; revert changes in SDL; update tests; run precommit.
```python
if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler):
    for _ in self.sampler_iter:
        pass
```
Hi, sorry to nitpick, but wouldn't we actually not want to skip an epoch when reloading a state dict from an unfinished epoch? I.e., IMO it makes sense to recognise the difference between a state_dict obtained via this process:

```python
for b in dl:
    sd_in = dl.state_dict()  # when the for ends, sd_in will describe an "about-to-finish" state
```

vs a state_dict obtained via this process:

```python
for b in dl:
    pass
sd_out = dl.state_dict()  # sd_out will describe a "just-finished" state
```

I think it makes sense to have an empty epoch immediately after loading sd_in, but a full one immediately after loading sd_out.
In particular, is it possible that issue #1437 is solved just by the new line `self.samples_yielded = state_dict[self._SAMPLES_YIELDED]`?
Hi @gailweiss

I think the code changes take care of that. In the state_dict, `['_iterator_finished']` is False or True depending on whether the iterator is just-finishing or actually-finished. When we load sd_in, the next epoch is still empty because `state_dict['_iterator_finished']` is False.

Is this the behavior you expect:
```python
from torchdata.stateful_dataloader import StatefulDataLoader


def get_dl():
    d = list(range(10))
    return StatefulDataLoader(d, batch_size=1, shuffle=True)


dl = get_dl()
for i, b in enumerate(dl):
    if i == 0:
        print(i, b)
    sd_in = dl.state_dict()  # when the for ends, sd_in will describe an "about-to-finish" state
print("sd_in", sd_in)

dl = get_dl()
dl.load_state_dict(sd_in)  # load the "about-to-finish" state
batches_after_sdin_load = []
for i, b in enumerate(dl):
    batches_after_sdin_load.append(b)
    if i == 0:
        print(i, b)
print("batches_after_sdin_load", batches_after_sdin_load)

dl = get_dl()
for i, b in enumerate(dl):
    if i == 0:
        print(i, b)
sd_out = dl.state_dict()  # when the for ends, sd_out will describe a "just-finished" state
print("sd_out", sd_out)

dl = get_dl()
dl.load_state_dict(sd_out)  # load the "just-finished" state
batches_after_sdout_load = []
for i, b in enumerate(dl):
    batches_after_sdout_load.append(b)
    if i == 0:
        print(i, b)
print("batches_after_sdout_load", batches_after_sdout_load)
```
Output:

```
0 tensor([5])
sd_in {'_index_sampler_state': None, '_sampler_iter_state': {'samples_yielded': 10, 'sampler_iter_state': {'generator': tensor([1, 0, 0, ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 10}}, '_sampler_iter_yielded': 10, '_num_yielded': 10, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': False}
batches_after_sdin_load []
0 tensor([5])
sd_out {'_index_sampler_state': None, '_sampler_iter_state': {'samples_yielded': 10, 'sampler_iter_state': {'generator': tensor([1, 0, 0, ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 10}}, '_sampler_iter_yielded': 10, '_num_yielded': 10, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
0 tensor([2])
batches_after_sdout_load [tensor([2]), tensor([8]), tensor([1]), tensor([5]), tensor([6]), tensor([9]), tensor([3]), tensor([7]), tensor([0]), tensor([4])]
```
Hi, yes, that's great! Thank you and sorry for the false complaint!
```diff
@@ -1441,6 +1444,154 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
         self._run_test(17, 19)
 
+class TestMultiEpochSDL_shard0(TestCase):
+    def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False):
```
Can you test this for num_workers=0 and, say, num_workers=2?
Removed the default values for num_workers. The tests below run for both 0 and 2.
```python
next(self.sampler_iter)

# Skip one epoch if we were at the end of the last epoch
if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler):
```
Do all samplers have `__len__`? This feels brittle to me.
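To illustrate the concern, here is a minimal, hypothetical sampler with no `__len__` (an infinite streaming sampler, written in plain Python; the class name is mine, not from torchdata). For such a sampler the `hasattr(self.sampler, "__len__")` guard is False, so the end-of-epoch branch would silently never run:

```python
import random


class StreamingSampler:
    """Hypothetical sampler without __len__: it yields indices forever,
    so there is no epoch length to compare samples_yielded against."""

    def __init__(self, max_index, seed=0):
        self.max_index = max_index
        self.rng = random.Random(seed)

    def __iter__(self):
        while True:
            yield self.rng.randrange(self.max_index)


sampler = StreamingSampler(10)
# The guard from the patch evaluates to False here, so the
# end-of-epoch handling is silently skipped for such samplers.
guard_applies = hasattr(sampler, "__len__")
print(guard_applies)  # False
```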
```python
    for _ in self.sampler_iter:
        pass
```
If the state_dict is saved at the end of epoch 1 / start of epoch 2, why do we need to do this? What if the sampler is not stateful: would it happen twice due to line 114? Something seems a bit off to me.

@andrewkho
Force-pushed from 4de1bb4 to 6d49b4f (compare). Commits: update state dict if the iterator has finished; add comment about why were updating state dict; run precommit.
TL;DR: After refactoring BatchSampler, the same batch sequence is repeated in the epoch following a reload, because _iterator_finished is True. The proposed fix updates the generator state in the state_dict after each iteration to cache the latest RNG state, ensuring the RNG resumes correctly even if next_yielded is reset to 0.
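The idea in the summary above can be sketched in plain Python, with `random.Random` standing in for `torch.Generator` (all class and method names here are illustrative, not the actual torchdata API): the generator state is cached after every yield, so a state_dict taken at end of epoch restores the RNG exactly even though the yield counter restarts at 0, and the next epoch does not repeat the previous batch order.

```python
import random


class CheckpointableRandomIter:
    """Sketch: cache the RNG state after each iteration so that
    save/load at any point (including end of epoch) resumes the RNG."""

    def __init__(self, n, seed=0):
        self.n = n
        self.rng = random.Random(seed)
        self.yielded = 0
        self._cached_state = self.rng.getstate()

    def __next__(self):
        idx = self.rng.randrange(self.n)
        self.yielded += 1
        self._cached_state = self.rng.getstate()  # update after each iteration
        return idx

    def state_dict(self):
        return {"generator": self._cached_state, "yielded": self.yielded}

    def load_state_dict(self, sd):
        self.rng.setstate(sd["generator"])
        self.yielded = sd["yielded"]


it = CheckpointableRandomIter(10)
first_epoch = [next(it) for _ in range(10)]
sd = it.state_dict()  # saved at end of epoch

restored = CheckpointableRandomIter(10)
restored.load_state_dict(sd)
restored.yielded = 0  # new epoch: counter resets, RNG state is kept
second_epoch = [next(restored) for _ in range(10)]

# Continuing the original iterator produces the same sequence,
# i.e. the reloaded loader does not replay first_epoch.
cont_epoch = [next(it) for _ in range(10)]
assert second_epoch == cont_epoch
```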
For future ref: in the torch.utils.data RandomSampler, we are changing the state of the generator even if …
```python
def update_state_dict(self) -> None:
    if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
        self.sampler_iter.update_state_dict()
```
Leaving thoughts here:
- This introduces a new (undocumented) API for samplers (`update_state_dict()`), which isn't great.
- If the alternative solution is to update the base sampler code, then users bringing their own samplers won't work out of the box, which also isn't great.
Thanks!
```python
self._dataset,
self._auto_collation,
self._collate_fn,
self._drop_last,
```
nit: can we remove the whitespace changes (i.e. run pre-commit to format all these files) so it's easier to review?
I have run pre-commit, and these whitespace changes are actually added by pre-commit. If I revert everything, I get lint errors.
```diff
@@ -42,13 +41,21 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self.sampler.generator.set_state(state_dict[self._GENERATOR])
         self.next_yielded = state_dict[self._YIELDED]
 
+    def update_state_dict(self) -> None:
```
@ramanishsingh can you try to remove this change, and instead fork the random sampler code here and make it an iterator instead of a generator? It will take a bit of refactoring, but it should make the save/load state more explicit and easier to reason about and handle.

If that works, check that the old RandomSampler works too as a drop-in replacement (it should just fast-forward), to ensure users bringing their own (non-stateful) samplers will have things work as well.
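The fast-forward fallback mentioned above can be sketched in plain Python (illustrative names; a non-stateful sampler is modeled as an ordinary generator function): when a sampler exposes no state, the loader can only replay and discard the indices that were already consumed.

```python
import random


def fast_forward(sampler_iter, n_yielded):
    """Naive resume for a sampler with no state_dict support:
    replay and discard the indices that were already consumed."""
    for _ in range(n_yielded):
        next(sampler_iter)
    return sampler_iter


def random_sampler(n, seed):
    """Hypothetical non-stateful sampler: a fresh shuffled permutation."""
    rng = random.Random(seed)
    order = list(range(n))
    rng.shuffle(order)
    yield from order


full = list(random_sampler(10, seed=42))
# Resume after 4 yielded samples: rebuild the sampler with the same
# seed and skip the first 4 indices.
resumed = list(fast_forward(random_sampler(10, seed=42), 4))
assert resumed == full[4:]
```

The cost is that resuming is O(n_yielded), which is why a stateful sampler that restores its RNG directly is preferable when available.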
Force-pushed from 0a90c04 to 39995a3 (compare). Commits: reverse changes to sdl.py; generator to iterator; run precommit; update generator usage.
Hi, I tried to be clever and implement a temporary workaround myself, but no luck :)

Code:

output:
@gailweiss can you try these examples using the code on this branch? I guess the code in this branch should be working fine (except for the random generator thing, which is taken care of in #1441).
```python
if self.chunk_index < num_full_chunks:
    if self.perm is None or not self.perm:
```
Why are we doing this chunking logic? I don't think this will give the correct distribution, i.e. it will depend on num_chunks.
Hi @andrewkho,

That's how they are doing it in https://github.com/pytorch/pytorch/blob/b18e3c01aa8bfee37078b4a06cef4361bf63b36b/torch/utils/data/sampler.py#L177: generate 32 random numbers at once and yield from them. In their case, `self.num_samples // 32` is essentially num_full_chunks.
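For reference, that upstream pattern can be sketched like this, with Python's `random` module standing in for `torch.Generator` and an illustrative function name (`chunked_random_indices` is mine): indices are drawn with replacement in chunks of 32, plus one final partial chunk for the remainder.

```python
import random


def chunked_random_indices(n, num_samples, rng, chunk=32):
    """Sketch of the upstream chunking pattern: draw `chunk` indices
    at a time with replacement, then a partial chunk for the rest."""
    num_full_chunks, remainder = divmod(num_samples, chunk)
    for _ in range(num_full_chunks):
        # one batched draw per chunk (torch does torch.randint(high=n, size=(32,)))
        batch = [rng.randrange(n) for _ in range(chunk)]
        yield from batch
    # final partial chunk covers num_samples % chunk indices
    yield from [rng.randrange(n) for _ in range(remainder)]


rng = random.Random(0)
out = list(chunked_random_indices(n=10, num_samples=70, rng=rng))
```

Since each index is drawn independently with replacement, the distribution does not actually depend on the chunk size; chunking only batches the RNG calls.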
```python
if self.perm is None or not self.perm:
    self.perm = torch.randint(
        high=self.n,
        size=(remainder,),
        dtype=torch.int64,
        generator=self.sampler.generator,
    ).tolist()
```
Can we implement this with less repeated code and fewer if/else branches? I think we need next to be as lightweight as possible.

I would suggest we generate the sequence once (i.e. at startup/load_state_dict or at the beginning of the function), e.g.:

```python
self.seq = generate_sequence(replacement=True)
ret = self.seq[self.i]
self.i += 1
return ret
```
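Fleshed out, that suggestion could look roughly like the sketch below (`PrecomputedSampler` and all names are hypothetical, with Python's `random` in place of `torch.Generator`): the whole epoch's index sequence is generated once, `__next__` is a plain list lookup, and save/load is just the sequence plus a cursor.

```python
import random


class PrecomputedSampler:
    """Sketch: precompute the epoch's index sequence (with replacement)
    so that next() is a lightweight list lookup and state is explicit."""

    def __init__(self, n, num_samples, seed=0):
        rng = random.Random(seed)
        self.seq = [rng.randrange(n) for _ in range(num_samples)]
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= len(self.seq):
            raise StopIteration
        ret = self.seq[self.i]
        self.i += 1
        return ret

    def state_dict(self):
        return {"seq": self.seq, "i": self.i}

    def load_state_dict(self, sd):
        self.seq, self.i = sd["seq"], sd["i"]


s = PrecomputedSampler(10, 10)
head = [next(s) for _ in range(4)]
sd = s.state_dict()

s2 = PrecomputedSampler(10, 10, seed=999)  # different seed on purpose
s2.load_state_dict(sd)                     # state fully overrides it
rest_resumed = list(s2)
rest_original = list(s)
assert rest_resumed == rest_original  # resumes exactly where s left off
```

The trade-off is memory proportional to num_samples, in exchange for a trivial `__next__` and a state that is easy to reason about.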
Is the chunking logic present in random sampler?
Yes, chunking is present in the torch.utils.data RandomSampler.
@ramanishsingh sure thing! It seems better, but not all fixed. And sadly it still breaks if the dataloaders are initiated with random generators. New output for the same code:
Add tests to reproduce and fix #1437