diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 1d7205365..fd2e66e38 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -273,7 +273,7 @@ def best_effort_sharding(shape, *, devices=None, mesh=None): return sharding else: # get the existing mesh and find the FSDP axis - fsdp_axis = mesh.axis_names.index(haliax.partitioning.ResourceAxis.DATA) + fsdp_axis = mesh.axis_names.index(hax.partitioning.ResourceAxis.DATA) num_devices = mesh.devices.shape[fsdp_axis] for i in range(len(shape) - 1, -1, -1): @@ -285,7 +285,7 @@ def best_effort_sharding(shape, *, devices=None, mesh=None): return NamedSharding(mesh, PartitionSpec(None)) axis_sharding = [None] * len(shape) - axis_sharding[sharded_axis] = haliax.partitioning.ResourceAxis.DATA + axis_sharding[sharded_axis] = hax.partitioning.ResourceAxis.DATA sharding = NamedSharding(mesh, PartitionSpec(*axis_sharding)) return sharding