From 33130028a60eb14c49cd13becb6097b24519d6aa Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Fri, 24 May 2024 17:55:41 -0700 Subject: [PATCH] small fix --- src/levanter/mesh.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/levanter/mesh.py b/src/levanter/mesh.py index ea6636db3..7b11f8e5d 100644 --- a/src/levanter/mesh.py +++ b/src/levanter/mesh.py @@ -40,14 +40,15 @@ def process_mesh_mapping(mesh) -> dict[int, int]: leftmost2uid = {} for i in range(jax.process_count()): - upper_left_position = tuple([np.min(axis) for axis in local_device_grid_positions(mesh, i)]) - upper_left_position[2][...] = 0 # we want the device with TP group index 0 in the same DP/FSDP group + tmp = [np.min(axis) for axis in local_device_grid_positions(mesh, i)] + tmp[-1] = 0 # we want the device with TP group index 0 in the same DP/FSDP group + upper_left_position = tuple(tmp) # in order to index into devices upper_left_process = devices[upper_left_position].process_index # assign uid to each process that has a device with TP group index 0 if upper_left_process not in leftmost2uid: - leftmost2uid[upper_left_position] = i + leftmost2uid[upper_left_process] = i i += 1 - uid = leftmost2uid[upper_left_position] + uid = leftmost2uid[upper_left_process] result[i] = uid return result