Ensure microbatching works with Partially Sharded Data Parallel #325
Draft conceptual implementation, since I think this might come up soon.

Essentially these mappings just reshape the DP axis to now be (R, DP), with params behaving differently. Grad accum now looks like:

```python
import equinox as eqx
import haliax as hax
from haliax import Axis

Replica = Axis("rep", num_replicas)
data = data.rearrange("{batch: (rep b) } -> ... rep (batch: b) ...", rep=Replica)

# normal grad_accum, but vmapped so there's a leading Replica axis on the results
grad_buffers = hax.vmap(grad_accum, Replica)(model, data)  # e.g. embeddings grad is [Replica, Vocab, Embed]
grad_buffers = hax.shard_with_axis_mapping(hax.mean(grad_buffers, axis=Replica), opt_state_mapping)

# do normal parameter updates, but ensure sharding
updates, opt_state = optimizer.update(grad_buffers, opt_state, params=model)
updates = hax.shard_with_axis_mapping(updates, param_mapping)
model = eqx.apply_updates(model, updates)
```
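For concreteness, here is a sketch of what those mappings might look like, assuming haliax-style axis mappings from logical axis names to mesh axis names. The specific names (`compute_axis_mapping`, the `"batch"`/`"embed"` logical axes, the `"replica"`/`"data"` mesh axes) are illustrative assumptions for this issue, not Levanter's actual config:

```python
# Hypothetical mappings for the (replica, data) scheme above; all names are illustrative.
compute_axis_mapping = {"batch": ("replica", "data")}  # activations: batch split over both replica and data
param_mapping = {"embed": "data"}                      # params: sharded over data only, replicated across replicas
opt_state_mapping = {"embed": ("replica", "data")}     # optimizer state: fully sharded, as used in the snippet above
```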
fixed in #588
(Note that grad accumulation / microbatches might be broken)
hrm
dlwh changed the title from "Partially Sharded Data Parallel" to "Ensure microbatching works with Partially Sharded Data Parallel" on May 28, 2024
Need to make sure we didn't break microbatching
On CUDA especially (and probably everything except TPU), I think we should investigate what I think of in my head as "partially sharded data parallel" or "replica fsdp" (cf. https://www.amazon.science/blog/near-linear-scaling-of-gigantic-model-training-on-aws). This might even be a good idea on TPU.
The motivation is that within-box communication is much cheaper than across-box, and NVIDIA GPUs have a ton of memory.
I think this is actually straightforward with mesh parallelism: add a new leading dimension to the device mesh called "replica" so that it's
(replica, data, model)
and partition the batch axis across replica and data, while partitioning model parameters only across data, leaving model for TP as usual. Seems like you still might want to fully partition optimizer states, but I dunno.
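A minimal sketch of that mesh layout in plain `jax.sharding` (not Levanter's actual config layer); the 2×2×2 device arrangement is just an assumption for illustration:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 devices, arranged as (replica=2, data=2, model=2) for illustration.
devices = np.array(jax.devices()).reshape(2, 2, 2)
mesh = Mesh(devices, axis_names=("replica", "data", "model"))

# The batch axis is split across both replica and data...
batch_sharding = NamedSharding(mesh, P(("replica", "data")))

# ...while a 2D parameter (e.g. [vocab, embed]) is sharded only over data
# (so it is replicated across replicas), with model left for TP as usual.
param_sharding = NamedSharding(mesh, P("data", "model"))
```

The intended payoff, per the linked AWS post, is that parameter all-gathers stay inside a replica group (where links are cheap), and only the gradient reduction has to cross the replica axis.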