Skip to content

Commit

Permalink
Bump haliax version requirement from 1.4.dev291 to 1.4.dev296
Browse files Browse the repository at this point in the history
  • Loading branch information
blahBlahhhJ committed May 17, 2024
1 parent 89d18c8 commit 803df3d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
# "haliax>=1.3,<2.0",
# Haliax changes in step with levanter, so we'll just use the git version except for releases.
# "haliax @ git+https://github.com/stanford-crfm/haliax.git@main",
"haliax>=1.4.dev291",
"haliax>=1.4.dev296",
"equinox>=0.11.4",
"jaxtyping>=0.2.20",
"transformers>=4.39.3",
Expand Down
2 changes: 1 addition & 1 deletion src/levanter/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Arra
def make_global_array_for_leaf(leaf_index, item_leaf_shape: Union[ShapeSpec, NamedShapeSpec]):
devices = jax.devices()
devices = np.array(devices).reshape(*self.mesh.devices.shape)
contiguous_mesh = jax.sharding.Mesh(devices, ("replica", ResourceAxis.DATA, ResourceAxis.MODEL))
contiguous_mesh = jax.sharding.Mesh(devices, (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL))
raw_array = jax.make_array_from_callback(
to_raw_shape(item_leaf_shape),
jax.sharding.NamedSharding(contiguous_mesh, self._pspec_for(item_leaf_shape)),
Expand Down
4 changes: 2 additions & 2 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def device_mesh(self) -> Mesh:
# devices = np.array(devices).reshape(
# self.replica_axis_size, self.data_axis_size // self.replica_axis_size, self.model_axis_size
# )
return Mesh(devices, ("replica", ResourceAxis.DATA, ResourceAxis.MODEL))
return Mesh(devices, (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL))

@property
def eval_batch_size(self):
Expand Down Expand Up @@ -702,7 +702,7 @@ def compute_axis_mapping(self) -> ResourceMapping:
axes_to_return[axis] = ResourceAxis.MODEL

if self.batch_axis is not None:
axes_to_return[self.batch_axis] = ("replica", ResourceAxis.DATA)
axes_to_return[self.batch_axis] = (ResourceAxis.REPLICA, ResourceAxis.DATA)

return axes_to_return

Expand Down

0 comments on commit 803df3d

Please sign in to comment.