
Fix end of epoch StatefulDataLoader restart #1439

Open · wants to merge 22 commits into main
Conversation


@ramanishsingh ramanishsingh commented Feb 3, 2025

Add tests to reproduce and fix #1437

@facebook-github-bot facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Feb 3, 2025
@ramanishsingh ramanishsingh marked this pull request as draft February 3, 2025 23:36

pytorch-bot bot commented Feb 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1439

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit da55b34 with merge base fe6b405:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ramanishsingh (Contributor Author)

This does not solve the problem as it just restarts the dataloader and produces the same batches again.

@gailweiss

If it's any help: while creating the issue, I noticed that after loading the state dict, the resulting state dict in the dataloader is different from the one that was loaded. For example, "samples_yielded" is set to 0 when the loaded one had 100 (see the prints in #1437), and there are possibly more differences that I haven't checked. Looking at the code in this commit, it seems that samples_yielded is being set manually; maybe that is the root of the problem?

update stateful_dataloader

run precommit

local changes

update test to test the order of batches

update test

update tests

revert changes in SDL

revert changes in SDL

update tests

run precommit
@ramanishsingh ramanishsingh force-pushed the fix_EndOfEpoch_sdl_restart branch from 8136e63 to a074b50 Compare February 4, 2025 22:48
@ramanishsingh ramanishsingh marked this pull request as ready for review February 5, 2025 06:37
if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler):
for _ in self.sampler_iter:
pass


Hi, sorry to nitpick, but wouldn't we actually not want to skip an epoch when reloading a state dict from an unfinished epoch? I.e., IMO it makes sense to recognise the difference between a state_dict obtained via this process:

for b in dl:
    sd_in = dl.state_dict()  # when the for ends, sd_in will describe an "about-to-finish" state

vs a state_dict obtained via this process:

for b in dl:
    pass
sd_out = dl.state_dict()  # sd_out will describe a "just finished" state

I think it makes sense to have an empty epoch immediately after loading sd_in, but a full one immediately after loading sd_out.

In particular, is it possible that issue #1437 is solved just by the new line self.samples_yielded = state_dict[self._SAMPLES_YIELDED]?

@ramanishsingh (Contributor Author) commented Feb 5, 2025

Hi @gailweiss,
I think the code changes take care of that. In the state_dict, ['_iterator_finished'] is False or True depending on whether the epoch is just finishing or has actually finished.
When we load sd_in, the next epoch is still empty because state_dict['_iterator_finished'] is False.

Is this the behavior you expect:

from torchdata.stateful_dataloader import StatefulDataLoader


def get_dl():
    d = list(range(10))
    return StatefulDataLoader(d, batch_size=1, shuffle=True)


dl = get_dl()
for i, b in enumerate(dl):
    if i == 0:
        print(i, b)
    sd_in = dl.state_dict()
print("sd_in", sd_in)


dl = get_dl()
dl.load_state_dict(sd_in)  # load the "about-to-finish" state
batches_after_sdin_load = []
for i, b in enumerate(dl):
    batches_after_sdin_load.append(b)
    if i == 0:
        print(i, b)

print("batches_after_sdin_load", batches_after_sdin_load)
dl = get_dl()
for i, b in enumerate(dl):
    if i == 0:
        print(i, b)

sd_out = (
    dl.state_dict()
)  # when the for ends, sd_out will describe a "just-finished" state
print("sd_out", sd_out)
dl = get_dl()
dl.load_state_dict(sd_out)  # load the "just-finished" state
batches_after_sdout_load = []
for i, b in enumerate(dl):
    batches_after_sdout_load.append(b)
    if i == 0:
        print(i, b)

print("batches_after_sdout_load", batches_after_sdout_load)

Output:

0 tensor([5])
sd_in {'_index_sampler_state': None, '_sampler_iter_state': {'samples_yielded': 10, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 10}}, '_sampler_iter_yielded': 10, '_num_yielded': 10, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': False}
batches_after_sdin_load []
0 tensor([5])
sd_out {'_index_sampler_state': None, '_sampler_iter_state': {'samples_yielded': 10, 'sampler_iter_state': {'generator': tensor([1, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8), 'yielded': 10}}, '_sampler_iter_yielded': 10, '_num_yielded': 10, '_IterableDataset_len_called': None, '_shared_seed': None, 'fetcher_state': None, 'dataset_state': None, '_iterator_finished': True}
0 tensor([2])
batches_after_sdout_load [tensor([2]), tensor([8]), tensor([1]), tensor([5]), tensor([6]), tensor([9]), tensor([3]), tensor([7]), tensor([0]), tensor([4])]


Hi, yes, that's great! Thank you and sorry for the false complaint!

@ramanishsingh ramanishsingh marked this pull request as draft February 5, 2025 14:02
@@ -1441,6 +1444,154 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
self._run_test(17, 19)


class TestMultiEpochSDL_shard0(TestCase):
def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False):
Contributor

can you test this for num_workers=0 and say num_workers=2?

Contributor Author

Removed default values for num_workers.
Tests below run for both 0 and 2.

next(self.sampler_iter)

# Skip one epoch if we were at the end of the last epoch
if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler):
Contributor

Do all samplers have __len__? This feels brittle to me

Comment on lines 119 to 120
for _ in self.sampler_iter:
pass
Contributor

If the state_dict is saved at the end of epoch 1 / start of epoch 2, why do we need to do this? What if the sampler is not stateful; would it happen twice because of line 114? Something seems a bit off to me.

@ramanishsingh ramanishsingh marked this pull request as ready for review February 5, 2025 22:56
@ramanishsingh (Contributor Author) commented Feb 5, 2025

@andrewkho
Thanks.
I took your implementation of BatchSamplerIterator from here.
I find that during loading of the state dict, if the _StatefulRandomSamplerIterator is at its end, its self.next_yielded value becomes None because of an iterator re-init somewhere.
To tackle that, I artificially set it to 0 by checking whether we are at the end of an epoch and exhausting the iterator (line 534 of stateful_dataloader.py).
I think this is less brittle than checking the length of the sampler and skipping one whole epoch. Please let me know your thoughts.

update state dict if the iterator has finished

add comment about why were updating state dict

run precommit
@ramanishsingh ramanishsingh force-pushed the fix_EndOfEpoch_sdl_restart branch from 4de1bb4 to 6d49b4f Compare February 7, 2025 07:05
@ramanishsingh (Contributor Author)

TLDR: After refactoring BatchSampler, the same batch sequence is repeated in the epoch following a reload because _iterator_finished is True. Updating the generator in the state_dict after each iteration caches the latest state, ensuring RNG resumes correctly even if next_yielded is reset to 0.

Problem:
After breaking the BatchSampler into BatchSampler and _BatchSamplerIterator, we encountered an issue where the same sequence of batches is produced in the epoch immediately following a reload, mirroring the last epoch before saving the state_dict.

Root Cause:
This issue arises because the dl state_dict is saved after the epoch completes, so _iterator_finished is set to True. To illustrate, consider the epoch after reloading as epoch 3. In the state_dict of the RandomSampler (a subset of the dl state_dict), the key items are self.next_yielded and the state of the generator. When a StatefulDataLoader (SDL) is instantiated with num_workers = 0 and batches are retrieved, the __iter__ method in SDL is invoked. This method uses next_iter_state (the loaded state_dict) to obtain an iterator, and during this process the generator, sampler_iter, etc. are reloaded. However, since _iterator_finished is True, the _StatefulSingleProcessDataLoaderIter that was created is discarded, and a new one is created with state_dict=None. Consequently, we lose the RandomSampler state: next_yielded is reset to 0 and the generator state remains at the start of epoch 2.

Proposed Solution:
While there may be more efficient solutions, one potential approach (that I have implemented) is to update the generator in the state_dict upon completing an iteration. By doing so, we cache the latest generator state, allowing us to resume RNG production from the correct point even when the RandomSampler is reset with next_yielded = 0.
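A minimal sketch of this idea, assuming a sampler-iterator class that keeps a cached snapshot of its state and whose sampler owns a torch.Generator; the class, attribute, and key names below are illustrative stand-ins based on this thread, not the exact PR code:

class _StatefulRandomSamplerIteratorSketch:
    # Hedged sketch only: names such as sampler, yielded, _GENERATOR and _YIELDED
    # are assumptions from this discussion, not the actual PR identifiers.
    _GENERATOR = "generator"
    _YIELDED = "yielded"

    def __init__(self, sampler):
        self.sampler = sampler  # assumed to own a torch.Generator at sampler.generator
        self.yielded = 0
        self.update_state_dict()  # cache an initial snapshot

    def update_state_dict(self) -> None:
        # Snapshot the generator state after each iteration so a reload can
        # resume RNG from the latest point even if yielded is reset to 0.
        self._cached_state = {
            self._GENERATOR: self.sampler.generator.get_state(),
            self._YIELDED: self.yielded,
        }

    def state_dict(self):
        return self._cached_state

The dataloader-side hook would then simply call update_state_dict() on the sampler iterator after each step, which matches the update_state_dict() hook discussed in the review comments below.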

@ramanishsingh (Contributor Author) commented Feb 10, 2025

For future reference:
https://github.com/pytorch/pytorch/blob/652880e8403b58ca44d482f200a8991b8b326e88/torch/utils/data/sampler.py#L190

In the torch.utils.data RandomSampler, we change the state of the generator even when self.num_samples % n == 0, in which case we don't use any samples from that permutation. A more efficient (so that we don't generate a randperm if we don't need one) and simpler solution would be to add a check self.num_samples % n > 0 and only then generate the final random permutation, as sketched below.
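Roughly, the suggestion amounts to the following, written here as a standalone helper (a paraphrase of the linked upstream logic, not a verbatim patch; the function name is illustrative):

import torch

def _random_indices_without_replacement(n, num_samples, generator):
    # Paraphrased from torch.utils.data.RandomSampler.__iter__ (without replacement);
    # the only change is guarding the final randperm behind a remainder check so the
    # generator state is untouched when no remainder samples are needed.
    for _ in range(num_samples // n):
        yield from torch.randperm(n, generator=generator).tolist()
    remainder = num_samples % n
    if remainder > 0:  # proposed check: skip the extra randperm when it is not used
        yield from torch.randperm(n, generator=generator).tolist()[:remainder]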

Comment on lines +120 to +122
def update_state_dict(self) -> None:
if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
self.sampler_iter.update_state_dict()
Contributor

Leaving thoughts here:

  • this introduces a new (undocumented) API for samplers (update_state_dict()) which isn't great
  • If the alternative solution is to update the base sampler code, this means users bringing their own samplers won't work out of the box, which also isn't great.

Contributor Author

Thanks!

self._dataset,
self._auto_collation,
self._collate_fn,
self._drop_last,
Contributor

nit: can we remove the whitespace changes (ie run pre-commit to format all these files) so it's easier to review?

Contributor Author

I have run pre-commit, and these whitespace changes are actually added by pre-commit. If I revert everything, I get lint errors.

@@ -42,13 +41,21 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.sampler.generator.set_state(state_dict[self._GENERATOR])
self.next_yielded = state_dict[self._YIELDED]

def update_state_dict(self) -> None:
Contributor

@ramanishsingh can you try to remove this change, and fork the random sampler code here and make it an iterator instead of a generator? It will take a bit of refactoring but should make the save/load state more explicit, easier to reason about, and easier to handle.

If that works, check that the old RandomSampler works too as a drop-in replacement (it should just fast-forward) to ensure users bringing their own (non-stateful) samplers will have things work as well.
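For illustration, an iterator-style random sampler along the lines the reviewer describes might look like the sketch below; the class name, attribute names, and snapshot/fast-forward details are hypothetical assumptions, not code from this PR:

import torch

class RandomSamplerIteratorSketch:
    # Hedged sketch: an explicit iterator (rather than a Python generator) whose
    # position and RNG state can be captured and restored directly.
    def __init__(self, n: int, generator: torch.Generator):
        self.n = n
        self.generator = generator
        # Snapshot the RNG state before drawing the permutation so that a restore
        # can regenerate exactly the same permutation.
        self._gen_state_before_perm = generator.get_state()
        self.perm = torch.randperm(n, generator=generator).tolist()
        self.yielded = 0

    def __iter__(self):
        return self

    def __next__(self) -> int:
        if self.yielded >= self.n:
            raise StopIteration
        idx = self.perm[self.yielded]
        self.yielded += 1
        return idx

    def state_dict(self) -> dict:
        return {"generator": self._gen_state_before_perm, "yielded": self.yielded}

    def load_state_dict(self, sd: dict) -> None:
        # Restore the pre-permutation RNG state, rebuild the permutation, then
        # fast-forward to the saved position.
        self.generator.set_state(sd["generator"])
        self._gen_state_before_perm = sd["generator"]
        self.perm = torch.randperm(self.n, generator=self.generator).tolist()
        self.yielded = sd["yielded"]

A plain, non-stateful sampler could be supported on top of this by simply fast-forwarding (calling next() the saved number of times) when no richer state is available.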

reverse changes to sdl.py

generator to iterator

run precommit

update generator usage
@ramanishsingh ramanishsingh force-pushed the fix_EndOfEpoch_sdl_restart branch from 0a90c04 to 39995a3 Compare February 11, 2025 06:09
@gailweiss commented Feb 11, 2025

Hi, I tried to be clever and implement a temporary workaround myself, but no luck :)
In the process, however, I wrote a bunch of tests that I think reflect how the stateful dataloader should behave, and it currently fails several. Maybe they will be helpful? They also surface the fact that once the stateful dataloaders are randomly initialised (see #1440), even if only by a user constructing them with different generators at first, even more things break, so it seems like loading the state dict doesn't completely wipe the dataloader's state...?

Code:

import torch
from copy import deepcopy


# test support functions: printing with notes, comparing dictionaries, comparing epochs,
# getting dataloaders, running tests, etc

class PrefPrint:
    def __init__(self):
        self.pref = ""
    def print(self, *a, **kw):
        print(self.pref, *a, **kw)
    
printer = PrefPrint()

def pprint(*a, **kw):
    printer.print(*a, **kw)

def same_continuation(orig, loaded, el, expect_false=False):
    l1 = [b.item() for b in orig]
    l2 = [b.item() for b in loaded]
    # print(l1, l2)
    if len(l1) != el:
        pprint(f"orig dl's epoch wrong length: {l1} (expected {el})")
        return False
    if len(l2) != el:
        pprint(f"loaded dl's epoch wrong length: {l2} (expected {el}) (orig dl's epoch was correct length)")
        return False
    res = (l1 == l2)
    if (not res) and (not expect_false):
        pprint("orig vs loaded dl epochs:", l1, l2)
    return res

def equal_state_dicts(d1, d2):
    def _comp_dicts(d1, d2, pref=">>\t"):
        if sorted(list(d1.keys())) != sorted(list(d2.keys())):
            return f"{pref} diff keys: {list(sd1.keys())}, {list(sd2.keys())}"
        for k in d1.keys():
            if isinstance(d1[k], dict):
                return _comp_dicts(d1[k], d2[k], pref=f"{pref}in {k}: ")
            elif isinstance(d1[k], torch.Tensor):
                if False in (d1[k] == d2[k]):
                    return f"{pref} diff on {k}: {d1[k].tolist()} vs {d2[k].tolist()}"
            elif d1[k] != d2[k]:
                return f"{pref} diff on {k}: {d1[k]} vs {d2[k]}"
        return ""
    res = _comp_dicts(d1, d2)
    if res:
        pprint(res)
        return False
    return True

def get_dl(from_seed=None, gen=None):
    assert None in [from_seed, gen]
    d = list(range(n_samples))
    if None is not from_seed:
        gen = torch.Generator()
        gen.manual_seed(from_seed)
    return DataLoader(d, generator=gen, batch_size=1, shuffle=True)

def run_test(f, *a, **kw):
    printer.pref = f"in {f.__name__}:"
    res = f(*a, **kw)
    printer.pref = ""
    return res

def test(f):
    tests.append(f)
    return f

tests = []
n_samples = 10 # length all the dataloaders will be in the tests

# tests themselves


# 1. 2 different inits create different shuffles
@test
def diffshuff_as_standard():
    # n = 10 -> 10! shuffles -> highly unlikely to accidentally get same epoch
    return not same_continuation(get_dl(), get_dl(), n_samples, expect_false=True)

# 2. 2 different inits from the same generator create the same shuffles
@test
def sameshuff_when_asked():
    if not same_continuation(get_dl(from_seed=1), get_dl(from_seed=1), n_samples):
        pprint("mismatch on seed 1")
        return False
    gen1, gen2 = torch.Generator(), torch.Generator()
    gen1.seed()
    gen2.set_state(gen1.get_state())
    if not same_continuation(get_dl(gen=gen1), get_dl(gen=gen2), n_samples):
        pprint("mismatch on random gen")
        return False
    return True

# 3. getting state dict and loading it after own state has changed recovers 
# previous state (ie state dict properly detached once taken)
@test
def go_back():
    dl1 = get_dl()
    sd = dl1.state_dict()
    a1 = [b.item() for b in dl1]
    dl1.load_state_dict(sd)
    a2 = [b.item() for b in dl1]
    if not a1 == a2:
        pprint("doesnt walk back steps")
        return False
    return True

# 4. loading a state dict taken from the middle of an epoch continues that epoch
@test
def resume_from_partial():
    dl1, dl2 = get_dl(), get_dl()
    a1 = []
    p = 3
    for i, b in enumerate(dl1):
        if i == p:
            sd = dl1.state_dict()
        if i > p:
            a1.append(b.item())
    dl2.load_state_dict(sd)
    a2 = [b.item() for b in dl2]
    res = a1 == a2
    if not res:
        pprint("diff continuation from partial sd")
        pprint(a1, a2)
    if not same_continuation(dl1, dl2, n_samples):
        pprint("next epoch after partial sd not aligned")
        res = False
    return res

# 5. loading a state dict taken after a dataloader has left a partial epoch resumes (like that dataloader) from that dataloader's next full epoch
@test
def resume_after_partial():
    dl1, dl2 = get_dl(), get_dl()
    a = []
    for i, b in enumerate(dl1):
        a.append(b.item())
        if i > 3:
            break
    sd = dl1.state_dict()
    dl2.load_state_dict(sd)
    return same_continuation(dl1, dl2, n_samples)


# 6. loading a state_dict taken after a full epoch has completed continues smoothly from that dataloader's next epoch 
# this highlights the main issue raised in https://github.com/pytorch/data/issues/1437
@test
def resume_between():
    g1, g2 = 2, 2
    dl1, dl2 = get_dl(), get_dl()
    _ = [[b.item() for b in dl1] for _ in range(g1)]
    sd = dl1.state_dict()
    dl2.load_state_dict(sd)
    for i in range(g2):
        if not same_continuation(dl1, dl2, n_samples):  # should be full epochs here
            pprint(f"epoch {i} after load is broken")
            return False
    return True


# 7. loading a state_dict taken partway through an epoch, specifically at the last batch (but before leaving the loop) resumes at same point (i.e., inside epoch, with 0 batches left), and then continues to the same next epoch as the original dataloader
@test
def resume_end():
    dl1, dl2 = get_dl(), get_dl()
    l1 = []
    for i, b in enumerate(dl1):
        l1.append(b.item())
        sd = dl1.state_dict()
        if i == n_samples - 1:
            break  # dl1 not finished
    pprint("dl1 first epoch is:", l1)

    dl2.load_state_dict(sd)
    l = [b.item() for b in dl2]
    if len(l) > 0:
        print("resuming finishing state does not lead to empty epoch")
        return False
    if not same_continuation(dl1, dl2, n_samples):
        pprint("resuming finishing state does not move (after finishing) to same next epoch")
        pprint("is this an off-by-one? dl1 next next epoch would be:", [b.item() for b in dl1])
        return False
    return True

# dataloader variants
from torchdata.stateful_dataloader import StatefulDataLoader as _DataLoader
import torch
from copy import deepcopy

class LoudDataLoader(_DataLoader):
    def load_state_dict(self, state_dict):
        sd1 = deepcopy(state_dict)
        super(LoudDataLoader, self).load_state_dict(sd1)
        equal_state_dicts(state_dict, self.state_dict())  # will print a difference if it finds one

# DataLoader to fix https://github.com/pytorch/data/issues/1440
class LoudDataLoader1440(LoudDataLoader):
    def __init__(self, *a, **kw):
        if None is kw.get("generator", None):
            kw["generator"] = torch.Generator()
            kw["generator"].seed()  # for some reason important for getting it going
        super().__init__(*a, **kw)



dlclasses = {"base":_DataLoader, "loud":LoudDataLoader, "loud1440":LoudDataLoader1440}
results = {}
for n, DLC in dlclasses.items():
    DataLoader = DLC
    print(f"\n================\nrunning {n} stateful dataloader tests")
    results[n] = {f.__name__ : run_test(f) for f in tests}

print("\n\n=======")
for n, r in results.items():
    print(f"\n=======\n {n} stateful dataloader test results:")
    names = [f.__name__ for f in tests]
    print("\n".join(f"{n}: \t\t[{r[n]}]" for n in names))

output:

================
running base stateful dataloader tests
in resume_after_partial: loaded dl's epoch wrong length: [8, 9, 3, 7, 4] (expected 10) (orig dl's epoch was correct length)
in resume_between: loaded dl's epoch wrong length: [] (expected 10) (orig dl's epoch was correct length)
in resume_between: epoch 0 after load is broken
in resume_end: dl1 first epoch is: [5, 6, 1, 2, 0, 8, 9, 3, 7, 4]
in resume_end: orig vs loaded dl epochs: [2, 3, 1, 8, 9, 0, 6, 7, 4, 5] [2, 8, 1, 5, 6, 9, 3, 7, 0, 4]
in resume_end: resuming finishing state does not move (after finishing) to same next epoch
in resume_end: is this an off-by-one? dl1 next next epoch would be: [2, 9, 8, 3, 6, 7, 1, 0, 4, 5]

================
running loud stateful dataloader tests
in go_back: >>	in _index_sampler_state:  diff on samples_yielded: 0 vs 10
in resume_from_partial: >>	in _index_sampler_state:  diff on samples_yielded: 4 vs 0
in resume_after_partial: >>	in _index_sampler_state:  diff on samples_yielded: 5 vs 0
in resume_after_partial: loaded dl's epoch wrong length: [8, 9, 3, 7, 4] (expected 10) (orig dl's epoch was correct length)
in resume_between: >>	in _index_sampler_state:  diff on samples_yielded: 10 vs 0
in resume_between: loaded dl's epoch wrong length: [] (expected 10) (orig dl's epoch was correct length)
in resume_between: epoch 0 after load is broken
in resume_end: dl1 first epoch is: [5, 6, 1, 2, 0, 8, 9, 3, 7, 4]
in resume_end: >>	in _index_sampler_state:  diff on samples_yielded: 10 vs 0
in resume_end: orig vs loaded dl epochs: [2, 3, 1, 8, 9, 0, 6, 7, 4, 5] [2, 8, 1, 5, 6, 9, 3, 7, 0, 4]
in resume_end: resuming finishing state does not move (after finishing) to same next epoch
in resume_end: is this an off-by-one? dl1 next next epoch would be: [2, 9, 8, 3, 6, 7, 1, 0, 4, 5]

================
running loud1440 stateful dataloader tests
in go_back: >>	in _index_sampler_state:  diff on samples_yielded: 0 vs 10
in go_back: doesnt walk back steps
in resume_from_partial: >>	in _index_sampler_state:  diff on samples_yielded: 4 vs 0
in resume_from_partial: diff continuation from partial sd
in resume_from_partial: [8, 9, 4, 1, 2, 3] [6, 8, 4, 2, 1, 7]
in resume_from_partial: orig vs loaded dl epochs: [2, 9, 6, 1, 5, 8, 3, 0, 4, 7] [8, 9, 0, 6, 3, 4, 7, 5, 1, 2]
in resume_from_partial: next epoch after partial sd not aligned
in resume_after_partial: >>	in _index_sampler_state:  diff on samples_yielded: 5 vs 0
in resume_after_partial: loaded dl's epoch wrong length: [4, 3, 2, 5, 1] (expected 10) (orig dl's epoch was correct length)
in resume_between: >>	in _index_sampler_state:  diff on samples_yielded: 10 vs 0
in resume_between: loaded dl's epoch wrong length: [] (expected 10) (orig dl's epoch was correct length)
in resume_between: epoch 0 after load is broken
in resume_end: dl1 first epoch is: [4, 6, 7, 9, 2, 3, 0, 5, 1, 8]
in resume_end: >>	in _index_sampler_state:  diff on samples_yielded: 10 vs 0
in resume_end: orig vs loaded dl epochs: [5, 3, 1, 6, 7, 9, 2, 8, 4, 0] [7, 8, 3, 9, 4, 2, 1, 0, 5, 6]
in resume_end: resuming finishing state does not move (after finishing) to same next epoch
in resume_end: is this an off-by-one? dl1 next next epoch would be: [0, 5, 6, 3, 4, 8, 2, 1, 7, 9]


=======

=======
 base stateful dataloader test results:
diffshuff_as_standard: 		[False]
sameshuff_when_asked: 		[True]
go_back: 		[True]
resume_from_partial: 		[True]
resume_after_partial: 		[False]
resume_between: 		[False]
resume_end: 		[False]

=======
 loud stateful dataloader test results:
diffshuff_as_standard: 		[False]
sameshuff_when_asked: 		[True]
go_back: 		[True]
resume_from_partial: 		[True]
resume_after_partial: 		[False]
resume_between: 		[False]
resume_end: 		[False]

=======
 loud1440 stateful dataloader test results:
diffshuff_as_standard: 		[True]
sameshuff_when_asked: 		[True]
go_back: 		[False]
resume_from_partial: 		[False]
resume_after_partial: 		[False]
resume_between: 		[False]
resume_end: 		[False]

@ramanishsingh (Contributor Author)

@gailweiss can you try these examples using the code on this branch? I guess the code in this branch should be working fine (except for the random generator thing, which is taken care of in #1441).

Comment on lines 42 to 43
if self.chunk_index < num_full_chunks:
if self.perm is None or not self.perm:
Contributor

Why are we doing this chunking logic? I don't think this will give the correct distribution, i.e., it will depend on num_chunks.

Contributor Author

Hi @andrewkho,
that's how it is done in https://github.com/pytorch/pytorch/blob/b18e3c01aa8bfee37078b4a06cef4361bf63b36b/torch/utils/data/sampler.py#L177 :
generate 32 random numbers at once and yield from them.
In their case self.num_samples // 32 is essentially num_full_chunks.
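For context, the replacement=True branch of the upstream sampler looks roughly like this, written here as a standalone helper (a paraphrase of the linked sampler.py; the function name is illustrative, and num_full_chunks corresponds to self.num_samples // 32):

import torch

def _random_indices_with_replacement(n, num_samples, generator):
    # Paraphrased from torch.utils.data.RandomSampler.__iter__ with replacement=True:
    # indices are drawn in chunks of 32 to amortize the cost of torch.randint calls.
    num_full_chunks = num_samples // 32
    for _ in range(num_full_chunks):
        yield from torch.randint(
            high=n, size=(32,), dtype=torch.int64, generator=generator
        ).tolist()
    remainder = num_samples % 32
    yield from torch.randint(
        high=n, size=(remainder,), dtype=torch.int64, generator=generator
    ).tolist()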

Comment on lines 59 to 65
if self.perm is None or not self.perm:
self.perm = torch.randint(
high=self.n,
size=(remainder,),
dtype=torch.int64,
generator=self.sampler.generator,
).tolist()
Contributor

Can we implement this with less repeated code and fewer if/else branches? I think we need next to be as lightweight as possible.

I would suggest we generate the sequence once (i.e., at startup/load_state_dict, or at the beginning of the function).

eg

self.seq = generate_sequence(replacement=True)
ret = self.seq[self.i]
self.i += 1
return ret
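A hedged sketch of that suggestion, fleshed out as a self-contained class; generate_sequence and all other names here are hypothetical stand-ins, not code from this PR:

import torch

class PrecomputedSamplerIterSketch:
    # Sketch of the "generate once, index into it" idea: the whole index sequence
    # is materialized up front so that __next__ is just a lookup and an increment.
    def __init__(self, n: int, num_samples: int, generator: torch.Generator):
        self.seq = self.generate_sequence(n, num_samples, generator)
        self.i = 0

    @staticmethod
    def generate_sequence(n, num_samples, generator):
        # With replacement, a single randint call produces the whole epoch's indices.
        return torch.randint(
            high=n, size=(num_samples,), dtype=torch.int64, generator=generator
        ).tolist()

    def __iter__(self):
        return self

    def __next__(self) -> int:
        if self.i >= len(self.seq):
            raise StopIteration
        ret = self.seq[self.i]
        self.i += 1
        return ret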

Contributor

Is the chunking logic present in RandomSampler?

Contributor Author

Yes, chunking is present in the torch.utils.data RandomSampler.

@gailweiss

@ramanishsingh sure thing!

It seems better, but not everything is fixed. And sadly it still breaks if the dataloaders are initialised with random generators. New output for the same code:

================
running base stateful dataloader tests
in resume_after_partial: loaded dl's epoch wrong length: [8, 9, 3, 7, 4] (expected 10) (orig dl's epoch was correct length)
in resume_end: dl1 first epoch is: [5, 6, 1, 2, 0, 8, 9, 3, 7, 4]

================
running loud stateful dataloader tests
in resume_after_partial: loaded dl's epoch wrong length: [8, 9, 3, 7, 4] (expected 10) (orig dl's epoch was correct length)
in resume_end: dl1 first epoch is: [5, 6, 1, 2, 0, 8, 9, 3, 7, 4]

================
running loud1440 stateful dataloader tests
in go_back: doesnt walk back steps
in resume_from_partial: diff continuation from partial sd
in resume_from_partial: [5, 2, 6, 7, 9, 1] [9, 2, 3, 7, 8, 6]
in resume_from_partial: orig vs loaded dl epochs: [3, 5, 2, 6, 1, 4, 0, 7, 9, 8] [2, 9, 5, 6, 4, 7, 0, 1, 8, 3]
in resume_from_partial: next epoch after partial sd not aligned
in resume_after_partial: loaded dl's epoch wrong length: [8, 5, 7, 0, 9] (expected 10) (orig dl's epoch was correct length)
in resume_between: orig vs loaded dl epochs: [8, 2, 9, 3, 0, 5, 4, 7, 1, 6] [0, 5, 6, 1, 3, 4, 8, 7, 2, 9]
in resume_between: epoch 0 after load is broken
in resume_end: dl1 first epoch is: [7, 3, 8, 4, 0, 2, 6, 1, 9, 5]
in resume_end: orig vs loaded dl epochs: [6, 9, 1, 0, 3, 4, 7, 5, 8, 2] [7, 2, 4, 9, 5, 1, 6, 3, 0, 8]
in resume_end: resuming finishing state does not move (after finishing) to same next epoch
in resume_end: is this an off-by-one? dl1 next next epoch would be: [3, 6, 7, 8, 2, 1, 9, 5, 0, 4]


=======

=======
 base stateful dataloader test results:
diffshuff_as_standard: 		[False]
sameshuff_when_asked: 		[True]
go_back: 		[True]
resume_from_partial: 		[True]
resume_after_partial: 		[False]
resume_between: 		[True]
resume_end: 		[True]

=======
 loud stateful dataloader test results:
diffshuff_as_standard: 		[False]
sameshuff_when_asked: 		[True]
go_back: 		[True]
resume_from_partial: 		[True]
resume_after_partial: 		[False]
resume_between: 		[True]
resume_end: 		[True]

=======
 loud1440 stateful dataloader test results:
diffshuff_as_standard: 		[True]
sameshuff_when_asked: 		[True]
go_back: 		[False]
resume_from_partial: 		[False]
resume_after_partial: 		[False]
resume_between: 		[False]
resume_end: 		[False]
