diff --git a/src/levanter/mesh.py b/src/levanter/mesh.py
index 7b11f8e5d..53d803fbe 100644
--- a/src/levanter/mesh.py
+++ b/src/levanter/mesh.py
@@ -20,12 +20,12 @@ def local_devices_mapping(mesh: Mesh, process_index: Optional[int] = None) -> di
     """Returns a mapping from local devices' DP/FSDP group index in global mesh to local indices"""
     local_device_pos = local_device_grid_positions(mesh, process_index)[:2]  # first 2 axes are DP axes.
     result = {}
-    j = 0
-    for i in range(len(local_device_pos[0])):
-        key = local_device_pos[0][i] * mesh.devices.shape[1] + local_device_pos[1][i]
+    uid = 0
+    for local_device_index in range(len(local_device_pos[0])):
+        key = local_device_pos[0][local_device_index] * mesh.devices.shape[1] + local_device_pos[1][local_device_index]
         if key not in result:
-            result[key] = j  # in case of TP=2, local device 0 and 2 will be mapped to same key.
-            j += 1
+            result[key] = uid  # in case of TP=2, local device 0 and 2 will be mapped to same key.
+            uid += 1
     return result
 
 
@@ -36,19 +36,19 @@ def process_mesh_mapping(mesh) -> dict[int, int]:
     """
     devices = mesh.devices
     result = {}  # maps process index to leftmost process index in DP/FSDP group
-    i = 0
+    uid = 0
     leftmost2uid = {}
-    for i in range(jax.process_count()):
-        tmp = [np.min(axis) for axis in local_device_grid_positions(mesh, i)]
+    for process_index in range(jax.process_count()):
+        tmp = [np.min(axis) for axis in local_device_grid_positions(mesh, process_index)]
         tmp[-1] = 0  # we want the device with TP group index 0 in the same DP/FSDP group
         upper_left_position = tuple(tmp)  # in order to index into devices
         upper_left_process = devices[upper_left_position].process_index
         # assign uid to each process that has a device with TP group index 0
         if upper_left_process not in leftmost2uid:
-            leftmost2uid[upper_left_process] = i
-            i += 1
-        uid = leftmost2uid[upper_left_process]
-        result[i] = uid
+            leftmost2uid[upper_left_process] = uid
+            uid += 1
+        # read the group's uid without rebinding `uid`: it is the next-free counter and must only grow
+        result[process_index] = leftmost2uid[upper_left_process]
     return result
 
 