Skip to content

Commit

Permalink
small fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
blahBlahhhJ committed May 25, 2024
1 parent 3313002 commit c34a573
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/levanter/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ def local_devices_mapping(mesh: Mesh, process_index: Optional[int] = None) -> di
"""Returns a mapping from local devices' DP/FSDP group index in global mesh to local indices"""
local_device_pos = local_device_grid_positions(mesh, process_index)[:2] # first 2 axes are DP axes.
result = {}
j = 0
for i in range(len(local_device_pos[0])):
key = local_device_pos[0][i] * mesh.devices.shape[1] + local_device_pos[1][i]
uid = 0
for local_device_index in range(len(local_device_pos[0])):
key = local_device_pos[0][local_device_index] * mesh.devices.shape[1] + local_device_pos[1][local_device_index]
if key not in result:
result[key] = j # in case of TP=2, local device 0 and 2 will be mapped to same key.
j += 1
result[key] = uid # in case of TP=2, local device 0 and 2 will be mapped to same key.
uid += 1
return result


Expand All @@ -36,19 +36,19 @@ def process_mesh_mapping(mesh) -> dict[int, int]:
"""
devices = mesh.devices
result = {} # maps process index to leftmost process index in DP/FSDP group
i = 0
uid = 0
leftmost2uid = {}

for i in range(jax.process_count()):
tmp = [np.min(axis) for axis in local_device_grid_positions(mesh, i)]
for process_index in range(jax.process_count()):
tmp = [np.min(axis) for axis in local_device_grid_positions(mesh, process_index)]
tmp[-1] = 0 # we want the device with TP group index 0 in the same DP/FSDP group
upper_left_position = tuple(tmp) # in order to index into devices
upper_left_process = devices[upper_left_position].process_index
# assign uid to each process that has a device with TP group index 0
if upper_left_process not in leftmost2uid:
leftmost2uid[upper_left_process] = i
i += 1
leftmost2uid[upper_left_process] = uid
uid += 1
uid = leftmost2uid[upper_left_process]
result[i] = uid
result[process_index] = uid

return result

0 comments on commit c34a573

Please sign in to comment.