Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
blahBlahhhJ committed May 26, 2024
1 parent 130c212 commit 68ed94e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
17 changes: 10 additions & 7 deletions src/levanter/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,39 @@ def local_device_grid_positions(mesh, process_index: Optional[int] = None) -> tu
Analogous to what np.where would return."""
if process_index is None:
process_index = jax.process_index()
# our device indices are [process_index * num_devices_per_node, (process_index + 1) * num_devices_per_node)
# we could be clever here and do math to figure out where we are in the grid, but it's simpler and less
# fragile to just search the grid for our devices

my_device_pos = np.vectorize(lambda dev: dev.process_index == process_index)(mesh.devices)
return my_device_pos.nonzero()


def local_devices_mapping(mesh: Mesh, process_index: Optional[int] = None) -> dict[int, int]:
"""Returns a mapping from local devices' DP/FSDP group index in global mesh to local indices"""
"""
Handles the case where different devices in the same process share the same data in TP.
Returns a mapping from the local devices' DP/FSDP group index in the global mesh to local indices.
"""
local_device_pos = local_device_grid_positions(mesh, process_index)[:2] # first 2 axes are DP axes.
result = {}
uid = 0
for local_device_index in range(len(local_device_pos[0])):
key = local_device_pos[0][local_device_index] * mesh.devices.shape[1] + local_device_pos[1][local_device_index]
if key not in result:
result[key] = uid # in case of TP=2, local device 0 and 2 will be mapped to same key.
# when two devices map to the same key (different TP index), they will get the same data
result[key] = uid
uid += 1
return result


def process_mesh_mapping(mesh) -> dict[int, int]:
"""
Handles the case where different processes share the same data in TP.
If we envision each process as a subgrid of the mesh for its devices, this is the position of the process
in the coarsened process-level mesh
"""
devices = mesh.devices
result = {} # maps process index to leftmost process index in DP/FSDP group
result = {}
uid = 0
leftmost2uid = {}

# basic logic: process index -> upper-left device -> TP index 0 device -> process index -> uid
for process_index in range(jax.process_count()):
tmp = [np.min(axis) for axis in local_device_grid_positions(mesh, process_index)]
tmp[-1] = 0 # we want the device with TP group index 0 in the same DP/FSDP group
Expand Down
8 changes: 1 addition & 7 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,10 +640,7 @@ def device_mesh(self) -> Mesh:
(self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size),
allow_split_physical_axes=True,
)
# devices = jax.devices()
# devices = np.array(devices).reshape(
# self.replica_axis_size, self.data_axis_size // self.replica_axis_size, self.model_axis_size
# )

return Mesh(devices, (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL))

@property
Expand Down Expand Up @@ -746,9 +743,6 @@ def _maybe_set_id(self):

# we can't do this in post_init because we don't want to call jax.device_count before calling distributed.initialize
def _validate_and_set_defaults(self):
if self.model_axis_size > 4:
raise ValueError(f"model axis size ({self.model_axis_size}) should not be greater than 4")

if jax.device_count() % self.model_axis_size != 0:
raise ValueError(
f"num_devices ({jax.device_count()}) is not divisible by model_axis_size ({self.model_axis_size})"
Expand Down

0 comments on commit 68ed94e

Please sign in to comment.