Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix end of epoch StatefulDataLoader restart #1439

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 153 additions & 2 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ def test(self):
dataset=dataset,
num_workers=num_workers,
collate_fn=identity,
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)
it = iter(dl)
# Fetch at least one batch from each worker
Expand All @@ -1325,7 +1325,10 @@ def test(self):
if num_workers > 0:
for i in range(num_workers):
# Ensure worker state is stored only once if the dataset is also the iterator
self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
self.assertEqual(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
None,
)
self.assertTrue(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
"dataset_iter_state"
Expand Down Expand Up @@ -1441,6 +1444,154 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
self._run_test(17, 19)


class TestMultiEpochSDL_shard0(TestCase):
def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
dataset = DummyMapDataset(data_size, shuffle=False)
return StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)

def _run(self, data_size, num_workers, batch_size, shuffle):
dl1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
# Run through the dataloader for 2 epochs and count the number of items yielded
num_items_yielded = 0
dl1_items = []
for _ in range(2):
for batch in dl1:
dl1_items.append(batch)
num_items_yielded += 1
# Save the state dict
state_dict = dl1.state_dict()
# Create a new StatefulDataLoader instance and load the state dict
new_dl1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
new_dl1.load_state_dict(state_dict)
# Run through the new dataloader for another 2 epochs and count the number of items yielded
additional_num_items_yielded = 0
for i in range(2):
epoch_num_items_yielded = 0
for batch in new_dl1:
dl1_items.append(batch)
epoch_num_items_yielded += 1
additional_num_items_yielded += epoch_num_items_yielded
# Check that the total number of items yielded is correct
self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)

# now run a second dataloder for 4 epochs and check if the order is same.
dl2 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
dl2_items = []
for _ in range(4):
for batch in dl2:
dl2_items.append(batch)

self.assertEqual(dl1_items, dl2_items)

def test_main_process(self):
self._run(100, 0, 1, False)

def test_multiprocess(self):
self._run(100, 2, 1, False)

def test_main_process_shuffle(self):
self._run(100, 0, 1, True)

def test_multiprocess_shuffle(self):
self._run(100, 2, 1, True)


class TestEndOfEpochBehavior_shard0(TestCase):
def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
dataset = DummyMapDataset(data_size, shuffle=False)
return StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)

def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int:
num_items_yielded = 0
for batch in data_loader:
num_items_yielded += 1
return num_items_yielded

def _run(self, data_size, num_workers, batch_size, shuffle):
dl = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
# Run through the dataloader for 1 epoch and count the number of items yielded
num_items_yielded = 0

for batch in dl:
num_items_yielded += 1
sd_in = dl.state_dict()
sd_out = dl.state_dict()

self.assertEqual(num_items_yielded, data_size)

# Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch
dl_sd_in = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
dl_sd_in.load_state_dict(sd_in)

# Run through the new dataloader for 1 epoch and count the number of items yielded
# num_items_yielded should be 0 since the state dict was saved before the end of epoch
num_items_yielded = self._count_items_yielded(dl_sd_in)
self.assertEqual(num_items_yielded, 0)

# Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch
dl_sd_out = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
dl_sd_out.load_state_dict(sd_out)

# Run through the new dataloader for 1 epoch and count the number of items yielded
# num_items_yielded should be data_size since the state dict was saved after the end of epoch
num_items_yielded = self._count_items_yielded(dl_sd_out)
self.assertEqual(num_items_yielded, data_size)

def test_main_process(self):
self._run(100, 0, 1, False)

def test_multiprocess(self):
self._run(100, 2, 1, False)

def test_main_process_shuffle(self):
self._run(100, 0, 1, True)

def test_multiprocess_shuffle(self):
self._run(100, 2, 1, True)


class TestMultiEpochState_shard0(TestCase):
def get_iterable_dl(self, pw, num_workers):
data_size = [25, 50, 100, 75]
Expand Down
95 changes: 50 additions & 45 deletions torchdata/stateful_dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def __next__(self) -> int:

self.yielded = self.next_yielded
self.next_yielded = None

val = next(self.parent_iterator)
self.yielded += 1
return val
Expand All @@ -42,13 +41,21 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.sampler.generator.set_state(state_dict[self._GENERATOR])
self.next_yielded = state_dict[self._YIELDED]

def update_state_dict(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ramanishsingh can you try to remove this change, and fork the random sampler code here and make it an iterator instead of a generator? It will take a bit of refactoring but should make the save/load state more explicit, easier to reason about and handle.

If that works, check that the old RandomSampler works too as a drop-in replacement (it shoudl just fast-forward) to ensure users bringing their own (non-stateful) samplers will have things work as well

self.generator_state = self.sampler.generator.get_state()
self.yielded = 0

def state_dict(self) -> Dict[str, Any]:
return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded}


class RandomSampler(torch.utils.data.sampler.RandomSampler):
def __init__(
self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None
self,
data_source: Sized,
replacement: bool = False,
num_samples: Optional[int] = None,
generator=None,
):
if generator is None:
# Ensure that underlying sampler has something repeatable
Expand All @@ -60,16 +67,30 @@ def __iter__(self):
return _StatefulRandomSamplerIterator(self, super().__iter__())


class BatchSampler(torch.utils.data.sampler.BatchSampler, Stateful):
class _BatchSamplerIterator(Iterator[list[int]], Stateful):
_SAMPLES_YIELDED = "samples_yielded"
_SAMPLER_STATE = "sampler_state"
_SAMPLER_ITER_STATE = "sampler_iter_state"

def __init__(self, sampler, batch_size, drop_last):
super().__init__(sampler, batch_size, drop_last)
def __init__(self, sampler, batch_size: int, drop_last: bool):
self.sampler = sampler
self.sampler_iter = iter(self.sampler)
self.batch_size = batch_size
self.drop_last = drop_last
self.samples_yielded = 0
self.next_yielded = None
self.sampler_iter = iter(sampler)

def __next__(self) -> list[int]:
batch = []
try:
for _ in range(self.batch_size):
batch.append(next(self.sampler_iter))
self.samples_yielded += 1
return batch
except StopIteration:
if self.drop_last or len(batch) == 0:
raise StopIteration
else:
return batch

def state_dict(self) -> Dict[str, Any]:
sd: Dict[str, Any] = {self._SAMPLES_YIELDED: self.samples_yielded}
Expand All @@ -80,7 +101,7 @@ def state_dict(self) -> Dict[str, Any]:
return sd

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.next_yielded = state_dict[self._SAMPLES_YIELDED]
self.samples_yielded = state_dict[self._SAMPLES_YIELDED]
if self._SAMPLER_STATE in state_dict:
assert isinstance(self.sampler, Stateful)
self.sampler.load_state_dict(state_dict[self._SAMPLER_STATE])
Expand All @@ -89,44 +110,28 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
assert isinstance(self.sampler_iter, Stateful)
self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE])

if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
self.sampler, _InfiniteConstantSampler
):
# We skip x samples if underlying sampler is not stateful
for _ in range(self.samples_yielded):
next(self.sampler_iter)

def update_state_dict(self) -> None:
if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
self.sampler_iter.update_state_dict()
Comment on lines +164 to +166
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving thoughts here:

  • this introduces a new (undocumented) API for samplers (update_state_dict()) which isn't great
  • If the alternative solution is to update the base sampler code, this means users bringing their own samplers won't work out of the box, which also isn't great.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!


Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi, sorry to nitpick, but wouldn't we actually not want to skip an epoch when reloading a state dict from an unfinished epoch? i.e. IMO it makes sense to recognise the difference between a state_dict obtained via this process:

for b in dl:
    sd_in = dl.state_dict()  # when the for ends, sd_in will describe an "about-to-finish" state

vs a state_dict obtained via this process:

for b in dl:
    pass
sd_out = dl.state_dict()  # sd_out will describe a "just finished" state

I think it makes sense to have an empty epoch immediately after loading sd_in, but a full one immediately after loading sd_out.

In particular, is it possible that issue #1437 is solved just by the new line self.samples_yielded = state_dict[self._SAMPLES_YIELDED]?

Copy link
Contributor Author

@ramanishsingh ramanishsingh Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @gailweiss
I think the code changes take care of that. In the state_dict ['_iterator_finished']: False or True depending on if it is just- finishing or actually-finished.
When we load sd_in, the next epoch is still empty because state_dict ['_iterator_finished']: False

Is this the behavior you expect :

from torchdata.stateful_dataloader import StatefulDataLoader


def get_dl():
    d = list(range(10))
    return StatefulDataLoader(d, batch_size=1, shuffle=True)


dl = get_dl()
for i, b in enumerate(dl):
    if i == 0:
        print(i, b)
    sd_in = dl.state_dict()
print("sd_in", sd_in)


dl = get_dl()
dl.load_state_dict(sd_in)  # load the "about-to-finish" state
batches_after_sdin_load = []
for i, b in enumerate(dl):
    batches_after_sdin_load.append(b)
    if i == 0:
        print(i, b)

print("batches_after_sdin_load", batches_after_sdin_load)
dl = get_dl()
for i, b in enumerate(dl):
    if i == 0:
        print(i, b)

sd_out = (
    dl.state_dict()
)  # when the for ends, sd_out will describe a "just-finished" state
print("sd_out", sd_out)
dl = get_dl()
dl.load_state_dict(sd_out)  # load the "about-to-finish" state
batches_after_sdout_load = []
for i, b in enumerate(dl):
    batches_after_sdout_load.append(b)
    if i == 0:
        print(i, b)

print("batches_after_sdout_load", batches_after_sdout_load)

Output:

0 tensor([5])
sd_in {'_index_sampler_state': None, '_sampler_iter_state': {'samples_yielded': 10, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 10}}, '_sampler_iter_yielded': 10, '_num_yielded': 10, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': False}
batches_after_sdin_load []
0 tensor([5])
sd_out {'_index_sampler_state': None, '_sampler_iter_state': {'samples_yielded': 10, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 10}}, '_sampler_iter_yielded': 10, '_num_yielded': 10, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
0 tensor([2])
batches_after_sdout_load [tensor([2]), tensor([8]), tensor([1]), tensor([5]), tensor([6]), tensor([9]), tensor([3]), tensor([7]), tensor([0]), tensor([4])]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, yes, that's great! Thank you and sorry for the false complaint!


class BatchSampler(torch.utils.data.sampler.BatchSampler):
def __init__(self, sampler, batch_size, drop_last):
super().__init__(sampler, batch_size, drop_last)

def __iter__(self):
if self.next_yielded is not None:
self.samples_yielded = self.next_yielded
if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
self.sampler, _InfiniteConstantSampler
):
# We skip x samples if underlying sampler is not stateful
for _ in range(self.next_yielded):
next(self.sampler_iter)
self.next_yielded = None
elif self.samples_yielded > 0:
# don't re-create sampler_iter unless necessary, we may already have one from init
self.sampler_iter = iter(self.sampler)
self.samples_yielded = 0

if self.drop_last:
while True:
try:
batch = []
for _ in range(self.batch_size):
batch.append(next(self.sampler_iter))
self.samples_yielded += 1
yield batch
except StopIteration:
break
else:
batch = [0] * self.batch_size
idx_in_batch = 0
for idx in self.sampler_iter:
self.samples_yielded += 1
batch[idx_in_batch] = idx
idx_in_batch += 1
if idx_in_batch == self.batch_size:
yield batch
idx_in_batch = 0
batch = [0] * self.batch_size
if idx_in_batch > 0:
yield batch[:idx_in_batch]
return _BatchSamplerIterator(
sampler=self.sampler,
batch_size=self.batch_size,
drop_last=self.drop_last,
)


class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler):
Expand Down
Loading