
Partially-sharded-data-parallel #588

Merged: 48 commits merged into main on May 28, 2024
Conversation

blahBlahhhJ (Contributor) commented May 16, 2024

An explanation:

Here I will ignore the difference between DP and FSDP and assume there are only 2 axes (DP, TP), because the DP and FSDP axes are handled together as a single flattened axis: DP_FSDP_index = DP_index * FSDP_size + FSDP_index.
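A hedged sketch of this flattening (the function and argument names are hypothetical, not levanter's actual code):

```python
def flatten_dp_fsdp_index(dp_index: int, fsdp_index: int, fsdp_size: int) -> int:
    # Treat (DP, FSDP) as one combined data-parallel axis, in row-major order.
    return dp_index * fsdp_size + fsdp_index
```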

Suppose we have this mesh of DP=n, TP=4 (d = device, p = process):

         TP0      TP1      TP2       TP3
DP0    (d0 p0)  (d1 p0)  (d4 p1)   (d5 p1)
DP1    (d8 p2)  (d9 p2)  (d12 p3)  (d13 p3)
...    a number of other DP groups in between, so that devices are non-contiguous
DPk    (d2 p0)  (d3 p0)  (d6 p1)   (d7 p1)

The 4 devices in each row (e.g. d0, d1, d4, d5) receive the same data and perform TP (the model is sharded across them). Devices in the same column, i.e. in different DP groups, receive different data.

We first take care of each process individually and ensure

  • devices in the same DP group (with different TP index) receive the same data
  • devices in different DP groups receive different data

This is handled by local_device_mapping(); a sketch follows the example mapping below.

From each device, we can extract its DP index from its position in the mesh and map it to a uid. In this example, we get the mapping

p0: {0: 0, k: 1}
p1: {0: 0, k: 1}
p2: {1: 0, ...}
p3: {1: 0, ...}
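A minimal sketch of what this per-process mapping might look like, assuming a plain 2-axis (DP, TP) jax.sharding.Mesh; this is illustrative only, not levanter's actual local_device_mapping():

```python
import jax

def local_device_mapping_sketch(mesh: jax.sharding.Mesh) -> dict[int, int]:
    """Illustrative only: map each DP index owned by this process to a local uid."""
    devices = mesh.devices  # ndarray of Devices indexed by (DP, TP)
    dp_size, tp_size = devices.shape

    # DP index of every device in the mesh, keyed by device id.
    dp_index_of = {devices[dp, tp].id: dp for dp in range(dp_size) for tp in range(tp_size)}

    # Assign uids 0, 1, ... to the distinct DP indices owned by this process,
    # in the order their devices appear locally.
    mapping: dict[int, int] = {}
    for d in jax.local_devices():
        dp = dp_index_of[d.id]
        if dp not in mapping:
            mapping[dp] = len(mapping)
    return mapping
```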

When we call make_array_from_callback() with the mesh, each device gets a slice of size per_device_batch_size, so the slice's start is global_begin = DP_index * per_device_batch_size. We can therefore recover DP_index from the slice:

DP_index = global_begin // per_device_batch_size

See batch_callback() in loader.py: each device will get

local_batch[uid*per_device_batch_size : (uid+1)*per_device_batch_size]

For p0, this means devices with DP index 0 (d0 & d1) get the first half of local_batch, and devices with DP index k (d2 & d3) get the second half. This satisfies the two bullet points above.
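Here is a hedged sketch of that slicing logic. The real code lives in batch_callback() in loader.py; the wrapper name make_batch_callback and the argument uid_of_dp_index are hypothetical stand-ins for the loader's state:

```python
def make_batch_callback(local_batch, uid_of_dp_index, per_device_batch_size):
    """Sketch only; mirrors the logic described above, not levanter's exact code."""
    def batch_callback(global_begin, _end):
        # Each DP group owns one contiguous per_device_batch_size chunk of the
        # (virtual) global batch, so the chunk's start encodes the DP index.
        dp_index = global_begin // per_device_batch_size
        # Map the DP index to the uid chosen by local_device_mapping(), then
        # return the matching slice of this process's local batch.
        uid = uid_of_dp_index[dp_index]
        return local_batch[uid * per_device_batch_size : (uid + 1) * per_device_batch_size]
    return batch_callback
```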

Next, we take care of the process-level stuff and ensure

  • different processes' devices that are in the same DP group (different TP index) receive the same data.
  • different processes' devices that are in different DP groups receive different data.

This is handled by process_mesh_mapping(); a sketch follows the resulting mapping below.

Each process:

  1. maps to its upper-left device in the mesh
  2. then to that device's (DP, TP) location
  3. then to the location at TP0 in the same DP group
  4. then to the device at that TP0 location
  5. then to the process that owns that device
  6. then to a uid
p0 -> d0 -> (DP0 TP0) -> (DP0 TP0) -> d0 -> p0 -> 0
p1 -> d4 -> (DP0 TP2) -> (DP0 TP0) -> d0 -> p0 -> 0
p2 -> d8 -> (DP1 TP0) -> (DP1 TP0) -> d8 -> p2 -> 1
p3 -> d12-> (DP1 TP2) -> (DP1 TP0) -> d8 -> p2 -> 1

Thus we get this mapping

{p0: 0, p1: 0, p2: 1, p3: 1}
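A hedged sketch of those six steps, again assuming a plain 2-axis (DP, TP) mesh; the function name is illustrative and this is not levanter's actual process_mesh_mapping():

```python
import jax
import numpy as np

def process_mesh_mapping_sketch(mesh: jax.sharding.Mesh) -> dict[int, int]:
    """Illustrative sketch of the six steps above for a (DP, TP) mesh."""
    devices = mesh.devices  # ndarray of Devices indexed by (DP, TP)
    proc_of = np.array([[d.process_index for d in row] for row in devices])

    owner_uid: dict[int, int] = {}  # uid assigned to each "TP0-owning" process
    mapping: dict[int, int] = {}    # final process -> uid mapping
    for proc in sorted(int(p) for p in set(proc_of.flat)):
        # Steps 1-2: the process's upper-left device and its (DP, TP) position.
        dp, _tp = np.argwhere(proc_of == proc)[0]
        # Steps 3-5: the device at TP0 in the same DP group, and its owning process.
        owner = int(proc_of[dp, 0])
        # Step 6: every process sharing that owner gets the same uid.
        if owner not in owner_uid:
            owner_uid[owner] = len(owner_uid)
        mapping[proc] = owner_uid[owner]
    return mapping
```

For the example mesh above, this yields {p0: 0, p1: 0, p2: 1, p3: 1}.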

The uid becomes the shard_idx of the dataloader. Thus, p0 and p1 get the same shard_idx and therefore the same local_batch.

At the device level, we already ensured that d0 & d1 get the first half and d2 & d3 get the second half. The same holds for p1: d4 & d5 get the first half and d6 & d7 get the second half. Because p0 and p1 have the same shard_idx, we further ensure that d0, d1, d4, d5 all get the same data, and d2, d3, d6, d7 all get the same data.

Therefore, these two mapping functions work for any DP/TP configuration :)
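To tie the two pieces together, here is a hypothetical bit of glue showing how the process-level uid could feed the dataloader's shard_idx; it mirrors the local_dataset.shard(...) call visible in the diff below, but the helper itself is not levanter code:

```python
import jax

def shard_for_this_process(local_dataset, process_uid: dict[int, int]):
    """Hypothetical glue: process_uid is the {process: uid} mapping computed above."""
    shard_idx = process_uid[jax.process_index()]
    num_shards = len(set(process_uid.values()))
    # Processes that share a uid get the same shard_idx, hence the same local_batch.
    return local_dataset.shard(shard_idx, num_shards)
```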

blahBlahhhJ marked this pull request as ready for review May 26, 2024 20:41
blahBlahhhJ requested a review from dlwh May 26, 2024 20:41
dlwh (Member) left a comment:

I'm confused. Can you explain (as a comment) why the data uid stuff is correct?

src/levanter/data/loader.py:

        self.item_dataset = local_dataset.shard(process_data_pos, num_data_process_groups)
        super().__init__(max_capacity, axis_resources)

    def _produce_batches(self) -> Iterator[PyTree]:
        one_item_generator = non_caching_cycle(self.item_dataset)
        batched = _batched(one_item_generator, self.local_batch_size)

        def batch_callback(global_begin, _):
            # global_begin is uid for DP/FSDP
            # DP_id * per_device_bs = global_begin
Member:

could you add a note explaining why this is correct (also so I can be sure I understand)

Contributor Author:

updated the PR description.

src/levanter/mesh.py:

If we envision each process as a subgrid of the mesh for its devices, then there is a process grid that
is a coarsened version of the mesh. This is the size of the process grid.
Handles the case when different processes share the same data in TP.
If we envision each process as a subgrid of the mesh for its devices, this is the position of the process
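For context, a minimal sketch of the coarsened process grid these docstrings describe, under the assumption that each process owns a contiguous rectangular subgrid of the device mesh (the comment below questions exactly this assumption for non-contiguous meshes); the names here are hypothetical:

```python
import jax
import numpy as np

def process_grid_sketch(mesh: jax.sharding.Mesh, local_subgrid_shape: tuple[int, int]):
    """Hypothetical illustration: the mesh shape and this process's position,
    both measured in per-process subgrids."""
    sub = np.array(local_subgrid_shape)
    grid_shape = tuple(np.array(mesh.devices.shape) // sub)

    # Upper-left device owned by this process, then coarsen to subgrid units.
    proc_of = np.vectorize(lambda d: d.process_index)(mesh.devices)
    upper_left = np.argwhere(proc_of == jax.process_index())[0]
    grid_pos = tuple(upper_left // sub)
    return grid_shape, grid_pos
```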
Member:

i don't actually understand how process_mesh is still a valid abstraction/idea in the world with "non-contiguous" device meshes

Contributor Author:

updated the PR description.

dlwh (Member) commented May 27, 2024:

also can you merge main so that the TPU tests run

dlwh (Member) left a comment:

ok i think i'm convinced!

dlwh merged commit ed3c6f1 into main May 28, 2024
5 checks passed
dlwh deleted the psdp branch May 28, 2024 06:29
rjpower pushed a commit to rjpower/levanter that referenced this pull request May 29, 2024
Ivan-Zhou pushed a commit that referenced this pull request May 29, 2024