Ensure microbatching works with Partially Sharded Data Parallel #325

Open
dlwh opened this issue Sep 26, 2023 · 4 comments
dlwh commented Sep 26, 2023

Need to make sure we didn't break microbatching


On CUDA especially (and probably everything except TPU), I think we should investigate what I think of as "partially sharded data parallel" or "replica FSDP" (cf. https://www.amazon.science/blog/near-linear-scaling-of-gigantic-model-training-on-aws). This might even be a good idea on TPU.

The motivation is that within-box communication is much cheaper than across-box, and NVIDIA GPUs have a ton of memory.

I think this is actually straightforward with mesh parallelism: add a new leading dimension to the device mesh called "replica" so that it's (replica, data, model) and partition the batch axis across replica and data, while partitioning model parameters only across data, leaving model for TP as usual.
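For concreteness, here's a minimal plain-JAX sketch of that mesh layout (device counts and axis names are illustrative, not Levanter's actual config):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# e.g. 8 devices arranged as 4 replicas x 2 data shards x 1 model shard
devices = np.array(jax.devices()).reshape(4, 2, 1)
mesh = Mesh(devices, axis_names=("replica", "data", "model"))

# The batch is split across both replica and data; parameters only across data,
# so each replica group holds a full copy of the (data-sharded) parameters.
batch_sharding = NamedSharding(mesh, P(("replica", "data"), None))
param_sharding = NamedSharding(mesh, P("data", None))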

Seems like you still might want to fully partition optimizer states, but I dunno

@dlwh dlwh added this to the GPU Parity milestone Sep 26, 2023
@dlwh dlwh added the p2 label Sep 26, 2023
@dlwh dlwh unassigned vadam5 Jan 23, 2024

dlwh commented Jan 31, 2024

Draft conceptual implementation, since I think this might come up soon.

  • Mesh is now (R, DP, TP)
  • There are now 3 mappings:
    • compute: {batch: (R, DP)} (a logical axis can map to more than one physical axis)
    • params: {embed: DP, rep: R} (rep will be used for grad buffers)
    • opt_state: {embed: (R, DP)}
  • TP works as before.

Essentially these mappings just reshape the DP axis to now be (R, DP), with params behaving differently.
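In haliax-style axis-mapping terms, a rough sketch of those three mappings might look like this (mesh axis names are assumptions, matching the (replica, data, model) mesh above):

# logical axis name -> physical mesh axis (or tuple of mesh axes)
compute_mapping   = {"batch": ("replica", "data")}
param_mapping     = {"embed": "data", "rep": "replica"}   # params replicated over replica; "rep" used for grad buffers
opt_state_mapping = {"embed": ("replica", "data")}        # optimizer state fully sharded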

  • Data is sharded across R and DP: batch -> (R, DP)
  • Params are sharded Embed -> (DP) (thus replicated across R)
  • Opt states are sharded Embed -> (R, DP)
  • Grads are buffered so that there is one independent copy of gradients on each replica. That is, grads are [rep, *param.axes]
  • Do grad accum on each replica independently.
  • Reduce and reshard grads, so that they are now Embed -> (R, DP) and no longer have a rep axis.

Grad accum now looks like:

Replica = Axis("rep", num_replicas)
data = data.rearrange("{batch: (rep b) } -> ... rep (batch: b) ...", rep=Replica)

# normal grad_accum, but vmapped so there's a leading Replica axis on the results
grad_buffers = hax.vmap(grad_accum, Replica)(model, data)  # e.g. embeddings grad is [Replica, Vocab, Embed]
grad_buffers = hax.shard_with_axis_mapping(hax.mean(grad_buffers, axis=Replica), opt_state_mapping)

# do normal parameter updates, but ensure sharding
updates, opt_state = optimizer.update(grad_buffers, opt_state, params=model)
updates = hax.shard_with_axis_mapping(updates, param_mapping)
model = eqx.apply_updates(model, updates)
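The grad_accum piece above is the microbatching step this issue is about. A minimal sketch of what it does, in plain JAX (the loss_fn signature and microbatch count are assumptions, not Levanter's actual implementation):

import jax
import jax.numpy as jnp

def grad_accum(model, data, loss_fn, num_microbatches):
    # Split the batch along its leading axis into microbatches and average
    # the per-microbatch gradients; under the vmap above this runs
    # independently on each replica's shard of the batch.
    micro = jax.tree_util.tree_map(
        lambda x: x.reshape(num_microbatches, -1, *x.shape[1:]), data)

    def step(acc, mb):
        grads = jax.grad(loss_fn)(model, mb)
        return jax.tree_util.tree_map(jnp.add, acc, grads), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, model)
    total, _ = jax.lax.scan(step, zeros, micro)
    return jax.tree_util.tree_map(lambda g: g / num_microbatches, total)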

@dlwh dlwh self-assigned this Mar 1, 2024
@dlwh dlwh assigned blahBlahhhJ and unassigned dlwh May 11, 2024

dlwh commented May 28, 2024

fixed in #588

@dlwh dlwh closed this as completed May 28, 2024
blahBlahhhJ (Contributor) commented

(Note that grad accumulation / microbatches might be broken)


dlwh commented May 28, 2024

hrm

@dlwh dlwh reopened this May 28, 2024
@dlwh dlwh changed the title Partially Sharded Data Parallel Ensure microbatching works with Partially Sharded Data Parallel May 28, 2024