
fix iterator overflow when gradient accumulation is 1 #35960

Merged, 1 commit, Jan 29, 2025

Conversation

@winglian (Contributor) commented on Jan 29, 2025

What does this PR do?

In the original part of the code:

            epoch_iterator = iter(epoch_dataloader)
            # We chunkify the epoch iterator into gradient accumulation steps `n` batches
            remainder = num_examples % args.gradient_accumulation_steps
            if remainder == 0:
                remainder = args.gradient_accumulation_steps
            update_step = -1
            total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1

If args.gradient_accumulation_steps is 1 and the length of the iterator/num_examples is, say, 10, then num_examples % 1 is always 0, so the check bumps remainder up to gradient_accumulation_steps, i.e. one extra step. For total updates we then get 10 steps // 1 + 1, so total_updates = 11, which is one more than the actual number of rows in the dataset/dataloader.

When the loop then calls self.get_batch_samples(epoch_iterator, num_batches) for that extra remainder step, there is no data left in the iterator, which I believe is what causes the segmentation fault I am seeing.
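
To make the off-by-one concrete, here is a standalone sketch with made-up numbers and plain Python stand-ins for the Trainer variables (not the Trainer code itself); the eleventh update step draws from an already exhausted iterator:

num_examples = 10
steps_in_epoch = 10
gradient_accumulation_steps = 1

remainder = num_examples % gradient_accumulation_steps              # 10 % 1 == 0
if remainder == 0:
    remainder = gradient_accumulation_steps                         # bumped to 1, i.e. one extra step
total_updates = steps_in_epoch // gradient_accumulation_steps + 1   # 10 // 1 + 1 == 11

epoch_iterator = iter(range(num_examples))
for update_step in range(total_updates):
    num_batches = remainder if update_step == total_updates - 1 else gradient_accumulation_steps
    batch_samples = [next(epoch_iterator, None) for _ in range(num_batches)]
    # the 11th update step prints [None]: the iterator is already exhausted
    print(update_step, batch_samples)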

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr or @SunMarc

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@SunMarc (Member) commented on Jan 29, 2025

There is indeed an issue there. We can go with your solution or re-write this to make it easier to follow, something like this.
I'll let @muellerzr decide.

remainder = num_examples % args.gradient_accumulation_steps

# Ensure that `remainder` logic doesn't add unnecessary steps
if remainder == 0:
    total_updates = steps_in_epoch // args.gradient_accumulation_steps
else:
    total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1

update_step = -1

# Loop through the calculated total updates
for _ in range(total_updates):
    update_step += 1

    # Properly handle the last step with the remainder logic
    if update_step == (total_updates - 1) and remainder != 0:
        num_batches = remainder
    else:
        num_batches = args.gradient_accumulation_steps
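
As a quick sanity check of this counting (a standalone sketch with made-up numbers, using steps_in_epoch for both quantities for simplicity), summing num_batches over the update steps covers the epoch exactly, so no extra draw from the iterator is ever requested:

def requested_batches(steps_in_epoch, gradient_accumulation_steps):
    remainder = steps_in_epoch % gradient_accumulation_steps
    if remainder == 0:
        total_updates = steps_in_epoch // gradient_accumulation_steps
    else:
        total_updates = steps_in_epoch // gradient_accumulation_steps + 1

    requested = 0
    for update_step in range(total_updates):
        if update_step == (total_updates - 1) and remainder != 0:
            requested += remainder
        else:
            requested += gradient_accumulation_steps
    return requested

# the GA=1 case from the description and a non-divisible case both cover exactly 10 batches
assert requested_batches(10, 1) == 10
assert requested_batches(10, 3) == 10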

@SunMarc SunMarc requested a review from muellerzr January 29, 2025 17:39

@muellerzr (Contributor) left a comment


Thanks for catching this!

@muellerzr merged commit 7547f55 into huggingface:main on Jan 29, 2025 (25 checks passed)
@winglian (Contributor, Author) commented:

I like your solution much better.

