Partially-sharded-data-parallel #588
Conversation
I'm confused. Can you explain (as a comment) why the data uid stuff is correct?
```python
        self.item_dataset = local_dataset.shard(process_data_pos, num_data_process_groups)
        super().__init__(max_capacity, axis_resources)

    def _produce_batches(self) -> Iterator[PyTree]:
        one_item_generator = non_caching_cycle(self.item_dataset)
        batched = _batched(one_item_generator, self.local_batch_size)

        def batch_callback(global_begin, _):
            # global_begin is uid for DP/FSDP
            # DP_id * per_device_bs = global_begin
```
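For readers outside the codebase, here is a rough sketch of what the two helpers in that hunk plausibly do. These stand-ins are illustrative only (their semantics are guessed from the names), not levanter's actual implementations:

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def non_caching_cycle(items: Iterable[T]) -> Iterator[T]:
    """Illustrative stand-in: re-iterate the dataset forever without materializing it."""
    while True:
        yield from items

def _batched(it: Iterator[T], n: int) -> Iterator[List[T]]:
    """Illustrative stand-in: group consecutive items into lists of length n."""
    batch: List[T] = []
    for x in it:
        batch.append(x)
        if len(batch) == n:
            yield batch
            batch = []
```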
could you add a note explaining why this is correct (also so I can be sure I understand)
updated the PR description.
```
    If we envision each process as a subgrid of the mesh for its devices, then there is a process grid that
    is a coarsened version of the mesh. This is the size of the process grid.

    Handles the case when different processes share the same data in TP.

    If we envision each process as a subgrid of the mesh for its devices, this is the position of the process
```
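To make the "coarsened mesh" idea in that docstring concrete, here is a minimal sketch, assuming each process owns a contiguous, equally-shaped subgrid of the device mesh; the shapes below are hypothetical and this is not the levanter API (the non-contiguous case discussed next complicates this picture):

```python
# Minimal sketch, assuming each process owns a contiguous subgrid of the mesh.
mesh_shape = (2, 4)      # (data, model) -- hypothetical
per_process = (2, 2)     # shape of the subgrid of devices each process owns -- hypothetical

# The process grid is the mesh with each per-process subgrid collapsed to one cell.
process_grid_shape = tuple(m // p for m, p in zip(mesh_shape, per_process))
print(process_grid_shape)  # (1, 2): two processes side by side along the model axis
```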
i don't actually understand how process_mesh is still a valid abstraction/idea in a world with "non-contiguous" device meshes
updated the PR description.
also can you merge main so that the TPU tests run
ok i think i'm convinced!
An explanation:
Here I will ignore the difference between DP and FSDP and assume there are only 2 axes (DP, TP), because all DP/FSDP axes are handled together as a single flattened index:

`DP_FSDP_index = DP_index * FSDP_size + FSDP_index`
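As a quick sanity check of that flattening (the sizes below are hypothetical, not values from the PR):

```python
DP_size, FSDP_size = 2, 4  # hypothetical sizes

def dp_fsdp_index(dp_index: int, fsdp_index: int) -> int:
    # row-major flattening of the (DP, FSDP) pair into a single data-parallel index
    return dp_index * FSDP_size + fsdp_index

assert dp_fsdp_index(1, 2) == 6
# every (DP, FSDP) pair maps to a distinct index in [0, DP_size * FSDP_size)
assert sorted(dp_fsdp_index(d, f) for d in range(DP_size) for f in range(FSDP_size)) == list(range(8))
```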
Suppose we have a mesh with DP=2 and TP=4, laid out like this (d → device, p → process; p0 owns d0–d3 and p1 owns d4–d7):

|       | TP 0 | TP 1 | TP 2 | TP 3 |
|-------|------|------|------|------|
| DP 0  | d0   | d1   | d4   | d5   |
| DP 1  | d2   | d3   | d6   | d7   |

The 4 devices in each row (e.g. d0, d1, d4, d5) get the same data and perform TP (sharding the model); devices in different rows, i.e. with different DP indices, receive different data.
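For concreteness, a mesh with exactly this device order could be built as follows. This is a minimal sketch, not code from the PR, and it assumes 8 devices split across 2 processes (e.g. a CPU run with `XLA_FLAGS=--xla_force_host_platform_device_count=8`). Note that in the mesh's row-major order each process's devices are interleaved with the other process's, which is the "non-contiguous" situation discussed above:

```python
import numpy as np
import jax
from jax.sharding import Mesh

d = np.array(jax.devices())   # assumes 8 devices

# Row 0 (DP index 0): d0 d1 d4 d5; row 1 (DP index 1): d2 d3 d6 d7.
# p0 owns d0-d3 and p1 owns d4-d7, so each TP row mixes devices from both processes.
layout = d[[0, 1, 4, 5,
            2, 3, 6, 7]].reshape(2, 4)
mesh = Mesh(layout, axis_names=("data", "model"))
print(mesh.devices)
```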
We first take care of each process individually and ensure that, within a process:

- devices with the same DP index get the same slice of the local batch, and
- devices with different DP indices get different slices.

This is handled by `local_device_mapping()`. From each device's position in the mesh we can extract its DP index and map it to a uid. In this example we get the mapping d0 → 0, d1 → 0, d2 → 1, d3 → 1 on p0 (and likewise d4 → 0, d5 → 0, d6 → 1, d7 → 1 on p1).

When we call `make_array_from_callback()` with the mesh, each device gets a slice of size `per_device_batch_size`, so the slice's start is `global_begin = DP_index * per_device_batch_size`, and we can recover `DP_index` from the slice; see `batch_callback()` in `loader.py`. For p0, this means devices with DP index 0 (d0 & d1) get the first half of `local_batch`, and devices with DP index 1 (d2 & d3) get the second half, which satisfies the two bullet points above (a runnable sketch of this follows below).
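Here is a minimal, self-contained sketch of that device-level correspondence. This is not levanter's loader; the mesh, batch size, and callback are hypothetical, and it assumes 8 local devices (e.g. a CPU run with `XLA_FLAGS=--xla_force_host_platform_device_count=8`):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

per_device_batch_size = 8                                   # hypothetical
devices = np.array(jax.devices()).reshape(2, 4)             # DP=2, TP=4
mesh = Mesh(devices, axis_names=("data", "model"))
global_batch_size = per_device_batch_size * mesh.shape["data"]

# The batch axis is sharded over DP only, so all devices in a TP row share one shard.
sharding = NamedSharding(mesh, P("data"))

def data_callback(index):
    # `index` is the slice of the global batch this device must provide.
    global_begin = index[0].start or 0
    dp_index = global_begin // per_device_batch_size         # global_begin = DP_index * per_device_bs
    # Fill the shard with its DP index so we can see which devices got which data.
    return np.full((per_device_batch_size,), dp_index, dtype=np.int32)

batch = jax.make_array_from_callback((global_batch_size,), sharding, data_callback)
print(batch)   # first half is all 0s (DP row 0), second half is all 1s (DP row 1)
```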
Next, we take care of the process-level side and ensure that processes whose devices cover the same set of DP indices are assigned the same shard of the dataset. This is handled by `process_mesh_mapping()`. Each process collects the DP indices covered by its devices and is mapped to a uid shared with every other process that covers the same set.
Thus we get the mapping p0 → 0, p1 → 0. The uid becomes the `shard_idx` of the dataloader: p0 and p1 get the same `shard_idx` because their devices cover exactly the same set of DP indices (they differ only along TP), and so they read the same `local_batch` (see the sketch below).
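A small sketch of that process-level grouping; the dictionaries below are hypothetical stand-ins for what `process_mesh_mapping()` derives from the mesh, not the actual implementation:

```python
from collections import defaultdict

def shard_idx_per_process(device_to_dp_index, device_to_process):
    # Collect the set of DP indices each process's devices cover...
    dp_rows = defaultdict(set)
    for dev, dp in device_to_dp_index.items():
        dp_rows[device_to_process[dev]].add(dp)
    # ...and give every process covering the same set the same shard index.
    group_ids, shard_idx = {}, {}
    for proc in sorted(dp_rows):
        key = frozenset(dp_rows[proc])
        shard_idx[proc] = group_ids.setdefault(key, len(group_ids))
    return shard_idx

# The example mesh: rows are DP groups, p0 owns d0-d3, p1 owns d4-d7.
device_to_process = {f"d{i}": ("p0" if i < 4 else "p1") for i in range(8)}
device_to_dp = {"d0": 0, "d1": 0, "d4": 0, "d5": 0, "d2": 1, "d3": 1, "d6": 1, "d7": 1}
print(shard_idx_per_process(device_to_dp, device_to_process))   # {'p0': 0, 'p1': 0}
```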
At the device level we already ensured that d0 & d1 get the first half and d2 & d3 get the second half; the same holds inside p1 (d4 & d5 get the first half, d6 & d7 the second). Now, because p0 and p1 have the same `shard_idx`, we further ensure that d0, d1, d4, d5 all get the same data and d2, d3, d6, d7 all get the same data.
Therefore, these two mapping functions work together for any DP/TP configuration :)