Encoding checkpoint reshaping guide #349

Draft: wants to merge 2 commits into base: main

tjruwase
Collaborator

This PR is a step towards generalizing the universal checkpointing approach that enables arbitrary reshapes of 3D parallel checkpoints. It eliminates the hardcoding of the BLOOM model architecture in the current implementation as follows:

  1. The client encodes any shape information required for extracting and merging tensor slices (e.g., which slices are to be averaged rather than concatenated). This information is included in the checkpoint file.
  2. Constant strings are replaced with symbolic constants defined by the DeepSpeed library.
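
To make the two points above concrete, here is a minimal sketch of the idea, not the code in this PR. The constant names, the `build_checkpoint_info` and `merge_slices` helpers, and the regex-based layout are all assumptions made for illustration; the actual symbolic constants are defined in DeepSpeed and may differ. Slices are plain 1-D Python lists to keep the example self-contained.

```python
import re

# Hypothetical symbolic constants (assumed names, for illustration only).
# In the real change these would be imported from DeepSpeed rather than
# spelled as string literals at each use site.
UNIVERSAL_CHECKPOINT_INFO = "universal_checkpoint_info"
PARAMETERS_TO_AVERAGE = "parameters_to_average"


def build_checkpoint_info():
    """Client side: describe how tensor-parallel slices should be merged."""
    return {
        # Regex patterns naming parameters whose slices are averaged
        # instead of concatenated (assumed example pattern).
        PARAMETERS_TO_AVERAGE: [r".*layernorm.*"],
    }


def merge_slices(name, slices, info):
    """Merge the per-rank 1-D slices of parameter `name` using `info`."""
    if any(re.fullmatch(p, name) for p in info[PARAMETERS_TO_AVERAGE]):
        # Element-wise mean across ranks.
        return [sum(vals) / len(slices) for vals in zip(*slices)]
    # Default: concatenate the slices in rank order.
    return [x for s in slices for x in s]


# The shape information travels inside the checkpoint itself, so the
# reshaping code needs no model-specific knowledge.
state_dict = {UNIVERSAL_CHECKPOINT_INFO: build_checkpoint_info()}
info = state_dict[UNIVERSAL_CHECKPOINT_INFO]

merged_norm = merge_slices("input_layernorm.weight", [[1.0, 3.0], [3.0, 5.0]], info)
merged_dense = merge_slices("dense.weight", [[1.0, 2.0], [3.0, 4.0]], info)
```

With this layout, supporting a different architecture would only mean changing what the client writes in `build_checkpoint_info`; the merge logic stays model-agnostic.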

Requires the companion DS PR.

@tjruwase
Collaborator Author

@stas00, I don't intend for this to be merged. Rather, I am sharing this PR to get your feedback on the generalization effort. As discussed earlier, the core logic will eventually move to DS.

Contributor

@stas00 stas00 left a comment

Looks pretty clean to me. Thank you for working on that, Tunji!

I know it's not as much fun when you're now working on it alone and without an immediately applicable context.

Review comment on megatron/checkpointing.py (outdated)
@tjruwase tjruwase marked this pull request as draft September 20, 2022 13:04
2 participants