FSDP Support #51

Open
andrew-bydlon opened this issue Apr 16, 2024 · 3 comments

Comments

@andrew-bydlon

This is a bit of a technical challenge and/or question. Both I-JEPA and V-JEPA use DDP rather than FSDP. Because DDP keeps a full replica of the model on every GPU, this inherently caps model size at what fits in a single GPU's memory.

I'm wondering if there is any thought being put into the support of JEPA with FSDP. In my mind, the flow would be to

  1. Ensure that the target and context encoders are sharded identically.
  2. Update only the locally held parameter shards on each node (could even be a performance improvement vs. DDP).
  3. During forward passes, share the locally updated weights with all nodes.
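To make the idea concrete, here is a minimal pure-Python sketch of why step 1 matters. It does not use FSDP at all: `shard`, `ema`, and `all_gather` are hypothetical stand-ins, and the parameter values are made up. Because the EMA update is elementwise, updating each rank's slice independently and then gathering gives the same result as updating the full parameter vector, provided both encoders are split at the same boundaries.

```python
def shard(flat, world_size):
    """Split a flat parameter list into equal contiguous chunks, zero-padding the tail."""
    chunk = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (chunk * world_size - len(flat))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def ema(target, context, momentum):
    """Elementwise EMA: target <- momentum * target + (1 - momentum) * context."""
    return [momentum * t + (1.0 - momentum) * c for t, c in zip(target, context)]


def all_gather(shards, numel):
    """Concatenate the per-rank shards and drop the padding."""
    return [x for s in shards for x in s][:numel]


# Toy flattened parameters (illustrative values only).
target = [1.0, 2.0, 3.0, 4.0, 5.0]
context = [0.0, 1.0, 0.0, 1.0, 0.0]
momentum = 0.5
world_size = 2

# Reference: EMA on the full, unsharded parameters.
full = ema(target, context, momentum)

# Steps 1-2: shard both encoders identically, update each rank's slice locally.
local = [
    ema(t, c, momentum)
    for t, c in zip(shard(target, world_size), shard(context, world_size))
]

# Step 3: gather the updated shards before the forward pass.
gathered = all_gather(local, len(target))
```

With identical shard boundaries, `gathered` matches `full` exactly, so no communication is needed for the update itself.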

I attempted to implement something like this on my side, but FSDP seems to shard the parameters somewhat unpredictably, i.e. not in a way that satisfies step 1 above.

Any suggestions?

@russellhowes

russellhowes commented Apr 16, 2024

Hi Andrew, thanks for reaching out. FSDP support is on our task list, but we haven't implemented it yet.

I'll keep this open, and check in with any updates on our side (or keep an eye out if you end up getting something working 🙂)

@andrew-bydlon
Author

Thanks Russell. Look forward to hearing more!

I've tried to implement it, but my loss starts creeping up after a few thousand steps. My hypothesis is that FSDP wrapping each module yields different flattened parameters per GPU, but I'm not totally sure. I keep the predictor and context encoder in one module, but include the encoder in my auto wrap policy, which is slightly different from this repo. So maybe it would work better here.
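A toy illustration of that hypothesis, in pure Python and not a diagnosis of what FSDP actually does: if the context encoder is flattened together with the predictor while the target encoder is flattened alone, the same rank ends up holding non-corresponding elements. The parameter names, values, flattening order, and the `flatten` and `shard` helpers are all made up for the example.

```python
def flatten(named_params, order):
    """Concatenate parameters into one flat list, in the given order."""
    return [x for name in order for x in named_params[name]]


def shard(flat, rank, world_size):
    """Return the contiguous slice of the flat list held by one rank."""
    chunk = -(-len(flat) // world_size)  # ceiling division
    return flat[rank * chunk:(rank + 1) * chunk]


# Target encoder flattened on its own.
target_unit = {"enc.w": [1.0, 1.0], "enc.b": [2.0, 2.0]}
# Context encoder flattened together with the predictor (hypothetical layout).
context_unit = {"pred.w": [9.0, 9.0], "enc.w": [3.0, 3.0], "enc.b": [4.0, 4.0]}

target_flat = flatten(target_unit, ["enc.w", "enc.b"])
context_flat = flatten(context_unit, ["pred.w", "enc.w", "enc.b"])

# What rank 0 of 2 holds locally for each unit.
target_rank0 = shard(target_flat, 0, 2)    # encoder weights only
context_rank0 = shard(context_flat, 0, 2)  # starts with predictor weights
```

Here rank 0 holds `enc.w` for the target but mostly `pred.w` for the context unit, so a shard-local EMA would blend encoder weights with predictor weights. If something like this is happening, it could be one explanation for a loss that slowly drifts upward.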

Hope the info helps. Will update if something magical happens 😅

@andrew-bydlon
Author

This can be achieved using FSDP.summon_full_params() for the EMA updates :)
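For anyone landing here, a minimal sketch of what that might look like. In PyTorch the context manager is exposed as `FullyShardedDataParallel.summon_full_params`. The function name `ema_update` and its arguments are hypothetical, and this assumes both encoders share the same architecture so that their parameters line up one to one. It has not been validated on a multi-GPU run.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


@torch.no_grad()
def ema_update(target_encoder: nn.Module, context_encoder: nn.Module, momentum: float) -> None:
    """target <- momentum * target + (1 - momentum) * context, on full parameters."""
    # Gather the unsharded parameters of both encoders on every rank.
    # writeback=True persists the in-place edits to the target's local shards
    # on exit; the context encoder is only read, so its writeback is skipped.
    with FSDP.summon_full_params(target_encoder, writeback=True), \
         FSDP.summon_full_params(context_encoder, writeback=False):
        for p_t, p_c in zip(target_encoder.parameters(), context_encoder.parameters()):
            p_t.mul_(momentum).add_(p_c, alpha=1.0 - momentum)
```

Since the full parameters are compared elementwise, the update no longer depends on how FSDP happened to shard each encoder. The trade-off is that the full parameters of both encoders are materialized on every rank for the duration of the update, so peak memory rises at that point.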
