-
Notifications
You must be signed in to change notification settings - Fork 159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix end of epoch StatefulDataLoader restart #1439
base: main
Are you sure you want to change the base?
Changes from 15 commits
1c69775
a074b50
50271b4
6ba9d94
1288e77
297d7bf
56c6882
6f3abf6
36c5b51
b9f194d
eb95deb
d783247
6d49b4f
20a14e5
093e5f2
39995a3
1ac45db
5167a94
34dc402
2d3da6c
53f37a0
da55b34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,7 +32,6 @@ def __next__(self) -> int: | |
|
||
self.yielded = self.next_yielded | ||
self.next_yielded = None | ||
|
||
val = next(self.parent_iterator) | ||
self.yielded += 1 | ||
return val | ||
|
@@ -42,13 +41,21 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: | |
self.sampler.generator.set_state(state_dict[self._GENERATOR]) | ||
self.next_yielded = state_dict[self._YIELDED] | ||
|
||
def update_state_dict(self) -> None: | ||
self.generator_state = self.sampler.generator.get_state() | ||
self.yielded = 0 | ||
|
||
def state_dict(self) -> Dict[str, Any]: | ||
return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded} | ||
|
||
|
||
class RandomSampler(torch.utils.data.sampler.RandomSampler): | ||
def __init__( | ||
self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None | ||
self, | ||
data_source: Sized, | ||
replacement: bool = False, | ||
num_samples: Optional[int] = None, | ||
generator=None, | ||
): | ||
if generator is None: | ||
# Ensure that underlying sampler has something repeatable | ||
|
@@ -60,16 +67,30 @@ def __iter__(self): | |
return _StatefulRandomSamplerIterator(self, super().__iter__()) | ||
|
||
|
||
class BatchSampler(torch.utils.data.sampler.BatchSampler, Stateful): | ||
class _BatchSamplerIterator(Iterator[list[int]], Stateful): | ||
_SAMPLES_YIELDED = "samples_yielded" | ||
_SAMPLER_STATE = "sampler_state" | ||
_SAMPLER_ITER_STATE = "sampler_iter_state" | ||
|
||
def __init__(self, sampler, batch_size, drop_last): | ||
super().__init__(sampler, batch_size, drop_last) | ||
def __init__(self, sampler, batch_size: int, drop_last: bool): | ||
self.sampler = sampler | ||
self.sampler_iter = iter(self.sampler) | ||
self.batch_size = batch_size | ||
self.drop_last = drop_last | ||
self.samples_yielded = 0 | ||
self.next_yielded = None | ||
self.sampler_iter = iter(sampler) | ||
|
||
def __next__(self) -> list[int]: | ||
batch = [] | ||
try: | ||
for _ in range(self.batch_size): | ||
batch.append(next(self.sampler_iter)) | ||
self.samples_yielded += 1 | ||
return batch | ||
except StopIteration: | ||
if self.drop_last or len(batch) == 0: | ||
raise StopIteration | ||
else: | ||
return batch | ||
|
||
def state_dict(self) -> Dict[str, Any]: | ||
sd: Dict[str, Any] = {self._SAMPLES_YIELDED: self.samples_yielded} | ||
|
@@ -80,7 +101,7 @@ def state_dict(self) -> Dict[str, Any]: | |
return sd | ||
|
||
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: | ||
self.next_yielded = state_dict[self._SAMPLES_YIELDED] | ||
self.samples_yielded = state_dict[self._SAMPLES_YIELDED] | ||
if self._SAMPLER_STATE in state_dict: | ||
assert isinstance(self.sampler, Stateful) | ||
self.sampler.load_state_dict(state_dict[self._SAMPLER_STATE]) | ||
|
@@ -89,44 +110,28 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: | |
assert isinstance(self.sampler_iter, Stateful) | ||
self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) | ||
|
||
if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( | ||
self.sampler, _InfiniteConstantSampler | ||
): | ||
# We skip x samples if underlying sampler is not stateful | ||
for _ in range(self.samples_yielded): | ||
next(self.sampler_iter) | ||
|
||
def update_state_dict(self) -> None: | ||
if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"): | ||
self.sampler_iter.update_state_dict() | ||
Comment on lines
+164
to
+166
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Leaving thoughts here:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! |
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hi, sorry to nitpick, but wouldn't we want to avoid skipping an epoch when reloading a state dict from an unfinished epoch? i.e. IMO it makes sense to recognise the difference between a state_dict obtained via this process:
vs a state_dict obtained via this process:
I think it makes sense to have an empty epoch immediately after loading sd_in, but a full one immediately after loading sd_out. In particular, is it possible that issue #1437 is solved just by the new line There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @gailweiss Is this the behavior you expect: from torchdata.stateful_dataloader import StatefulDataLoader
def get_dl():
d = list(range(10))
return StatefulDataLoader(d, batch_size=1, shuffle=True)
dl = get_dl()
for i, b in enumerate(dl):
if i == 0:
print(i, b)
sd_in = dl.state_dict()
print("sd_in", sd_in)
dl = get_dl()
dl.load_state_dict(sd_in) # load the "about-to-finish" state
batches_after_sdin_load = []
for i, b in enumerate(dl):
batches_after_sdin_load.append(b)
if i == 0:
print(i, b)
print("batches_after_sdin_load", batches_after_sdin_load)
dl = get_dl()
for i, b in enumerate(dl):
if i == 0:
print(i, b)
sd_out = (
dl.state_dict()
) # when the for ends, sd_out will describe a "just-finished" state
print("sd_out", sd_out)
dl = get_dl()
dl.load_state_dict(sd_out)  # load the "just-finished" state
batches_after_sdout_load = []
for i, b in enumerate(dl):
batches_after_sdout_load.append(b)
if i == 0:
print(i, b)
print("batches_after_sdout_load", batches_after_sdout_load) Output:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi, yes, that's great! Thank you and sorry for the false complaint! |
||
|
||
class BatchSampler(torch.utils.data.sampler.BatchSampler): | ||
def __init__(self, sampler, batch_size, drop_last): | ||
super().__init__(sampler, batch_size, drop_last) | ||
|
||
def __iter__(self): | ||
if self.next_yielded is not None: | ||
self.samples_yielded = self.next_yielded | ||
if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( | ||
self.sampler, _InfiniteConstantSampler | ||
): | ||
# We skip x samples if underlying sampler is not stateful | ||
for _ in range(self.next_yielded): | ||
next(self.sampler_iter) | ||
self.next_yielded = None | ||
elif self.samples_yielded > 0: | ||
# don't re-create sampler_iter unless necessary, we may already have one from init | ||
self.sampler_iter = iter(self.sampler) | ||
self.samples_yielded = 0 | ||
|
||
if self.drop_last: | ||
while True: | ||
try: | ||
batch = [] | ||
for _ in range(self.batch_size): | ||
batch.append(next(self.sampler_iter)) | ||
self.samples_yielded += 1 | ||
yield batch | ||
except StopIteration: | ||
break | ||
else: | ||
batch = [0] * self.batch_size | ||
idx_in_batch = 0 | ||
for idx in self.sampler_iter: | ||
self.samples_yielded += 1 | ||
batch[idx_in_batch] = idx | ||
idx_in_batch += 1 | ||
if idx_in_batch == self.batch_size: | ||
yield batch | ||
idx_in_batch = 0 | ||
batch = [0] * self.batch_size | ||
if idx_in_batch > 0: | ||
yield batch[:idx_in_batch] | ||
return _BatchSamplerIterator( | ||
sampler=self.sampler, | ||
batch_size=self.batch_size, | ||
drop_last=self.drop_last, | ||
) | ||
|
||
|
||
class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ramanishsingh can you try to remove this change, and fork the random sampler code here and make it an iterator instead of a generator? It will take a bit of refactoring but should make the save/load state more explicit, easier to reason about and handle.
If that works, check that the old RandomSampler works too as a drop-in replacement (it should just fast-forward) to ensure users bringing their own (non-stateful) samplers will have things work as well