
Universal checkpoint for zero stage 3 #5475

Merged (21 commits) on Jun 26, 2024

Conversation

xylian86 (Contributor)

This PR enables the universal checkpoint for zero stage 3.

Notes:

  • The current implementation supports data parallelism.
  • Development is ongoing for universal checkpointing of Stage 3 with tensor-slicing model parallelism.
  • Pipeline parallelism is not supported by ZeRO Stage 3 and hence is not included in this universal checkpoint implementation.

In this PR:

  • I've updated deepspeed/checkpoint/ds_to_universal.py to support converting ZeRO checkpoints into Universal checkpoints.
  • I've updated deepspeed/runtime/zero/stage3.py to enable loading Universal checkpoints using the Stage 3 optimizer.
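
For readers who want to try this end to end, here is a minimal sketch of the intended workflow. The ds_to_universal.py argument names and the "checkpoint": {"load_universal": true} config key are assumptions based on current DeepSpeed code and may differ across versions; treat this as an illustration rather than the canonical usage.

```python
# A minimal sketch of the convert-then-resume workflow (names are assumptions,
# not guaranteed to match every DeepSpeed version).
#
# Step 1 (shell): convert the saved ZeRO-3 checkpoint into the universal format:
#   python deepspeed/checkpoint/ds_to_universal.py \
#       --input_folder  checkpoints/global_step1000 \
#       --output_folder checkpoints/global_step1000_universal
#
# Step 2: resume training with universal-checkpoint loading enabled.
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # toy model; reuse your original model and config in practice

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-3}},
    # Assumed key telling the engine the checkpoint being loaded is in universal format.
    "checkpoint": {"load_universal": True},
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)
engine.load_checkpoint("checkpoints", tag="global_step1000_universal")
```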

@xylian86 (Contributor Author) commented Apr 29, 2024 via email

@tjruwase tjruwase requested review from samadejacobs, tohtana and lekurile and removed request for mrwyattii May 2, 2024 23:05

@tohtana tohtana left a comment


We have a test for universal checkpointing. It currently supports DP scaling only, but it would be good to exercise the ZeRO-3 feature with this test as well; you can just add "3" to the test argument.
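
For illustration, a hypothetical sketch of what such a parametrization could look like; the actual test name, fixtures, and argument in DeepSpeed's test suite may differ.

```python
# Hypothetical sketch of parametrizing a universal-checkpointing test over ZeRO
# stages; the real test in DeepSpeed's suite may be structured differently.
import pytest


@pytest.mark.parametrize("zero_stage", [1, 2, 3])  # "3" is the new case for this PR
def test_universal_checkpoint_dp_scaling(zero_stage):
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "zero_optimization": {"stage": zero_stage},
    }
    # A real test would train a few steps, save a checkpoint, convert it with
    # ds_to_universal.py, reload it at a different DP degree, and compare the
    # optimizer states; here we only show the parametrization itself.
    assert ds_config["zero_optimization"]["stage"] == zero_stage
```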

deepspeed/runtime/zero/stage3.py (review thread, resolved)
deepspeed/checkpoint/ds_to_universal.py (review thread, resolved)
(excerpt under review)

return int(text) if text.isdigit() else text


def natural_keys(text):
Contributor:

Thank you for introducing this interesting approach.
We have a similar sort in _merge_zero_shards, but it uses a different approach, and it is not good to have two different sorting implementations for the same purpose. Can you replace that one with natural_keys?

Contributor Author:

Thanks for the suggestions!
For this natural_keys function, I actually reused it from zero_to_fp32.py.

You're right; it's not ideal to have two different implementations of the same function. How about I replace the one in _merge_zero_shards with this natural_keys?

Contributor:

This sounds good, thank you
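
For context, natural_keys (as defined in zero_to_fp32.py) splits a string into text and integer chunks so that shard file names sort numerically rather than lexicographically. A small sketch with made-up shard names (not the exact file names DeepSpeed produces):

```python
import re


def atoi(text):
    return int(text) if text.isdigit() else text


def natural_keys(text):
    # Split e.g. "shard_10" into ['shard_', 10, ''] so numeric parts compare as integers.
    return [atoi(c) for c in re.split(r'(\d+)', text)]


shards = ["shard_1", "shard_10", "shard_2"]
print(sorted(shards))                    # lexicographic: shard_1, shard_10, shard_2
print(sorted(shards, key=natural_keys))  # natural order: shard_1, shard_2, shard_10
```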

@xylian86 (Contributor Author) commented Jun 3, 2024

[Figure: convergence curve for ZeRO-3 using the current implementation]

@xylian86 xylian86 requested a review from loadams as a code owner June 5, 2024 13:20
@tohtana tohtana enabled auto-merge June 24, 2024 16:13
@tohtana tohtana added this pull request to the merge queue Jun 26, 2024
Merged via the queue into deepspeedai:master with commit d2b1d7f Jun 26, 2024
15 checks passed
loadams added two commits that referenced this pull request on Jul 3, 2024
@Lomax314 commented Jul 30, 2024

Hello, thank you for your work!
I've encountered some issues in practice and would appreciate your help in analyzing them.
I am using ZeRO Stage 3 (DeepSpeed only) for training and use model.save_checkpoint to save the optimizer state (optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)). Converting the original optimizer state to the universal_checkpoint format and loading the converted checkpoint did not report any errors. However, the training loss did not continue from where it left off; instead it started to gradually decrease from 9, which is a very high loss value. I am puzzled by this and would appreciate your assistance in analyzing it.

Additionally, I used model.state_dict() and uni_model.state_dict() to get the model state before and after the conversion and compared them using compare_state_dicts. The result triggered an assertion, as torch.equal(s0.to('cpu'), s1.to('cpu')) evaluated to False.
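
For debugging this kind of mismatch, here is a hedged sketch of a tensor-by-tensor comparison that lists which parameters differ instead of stopping at the first assertion. Note that compare_state_dicts above is the commenter's own helper; report_state_dict_diffs below is a separate, hypothetical one.

```python
import torch


def report_state_dict_diffs(sd0, sd1, atol=0.0, rtol=0.0):
    """Hypothetical helper: list parameters that differ between two state dicts."""
    diffs = []
    for name in sorted(sd0.keys() | sd1.keys()):
        if name not in sd0 or name not in sd1:
            diffs.append((name, "missing in one state dict"))
            continue
        s0 = sd0[name].detach().cpu().float()
        s1 = sd1[name].detach().cpu().float()
        if s0.shape != s1.shape:
            diffs.append((name, f"shape {tuple(s0.shape)} vs {tuple(s1.shape)}"))
        elif not torch.allclose(s0, s1, atol=atol, rtol=rtol):
            diffs.append((name, f"max abs diff {(s0 - s1).abs().max().item():.3e}"))
    return diffs


# Usage (model / uni_model come from the commenter's own setup):
# for name, why in report_state_dict_diffs(model.state_dict(), uni_model.state_dict()):
#     print(name, why)
```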

@ArtificialZeng:

(quoting the PR description above)
How do we load the converted universal checkpoint?

@xylian86 (Contributor Author):

@ArtificialZeng We have released examples in the Megatron-DeepSpeed repository. You can find them at:
https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing.
Please let us know if you encounter any issues.
