Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
blahBlahhhJ committed May 25, 2024
1 parent 2331ae8 commit 8b7b804
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/levanter/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
def local_device_grid_positions(mesh, process_index: Optional[int] = None) -> tuple[np.ndarray, np.ndarray]:
"""Returns a tuple of nd arrays, one for each axis, indicating the position of each device on the grid.
Analogous to what np.where would return."""
pi = process_index or jax.process_index()
if process_index is None:
process_index = jax.process_index()
# our device indices are [process_index * num_devices_per_node, (process_index + 1) * num_devices_per_node)
# we could be clever here and do math to figure out where we are in the grid, but it's simpler and less
# fragile to just search the grid for our devices
my_device_pos = np.vectorize(lambda dev: dev.process_index == pi)(mesh.devices)
my_device_pos = np.vectorize(lambda dev: dev.process_index == process_index)(mesh.devices)
return my_device_pos.nonzero()


Expand Down

0 comments on commit 8b7b804

Please sign in to comment.