From d9c36e783fc5f6cade3e40e92bd68cde62ab0a88 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 23 Nov 2024 23:51:28 -0800 Subject: [PATCH] missed some renames? --- src/levanter/utils/jax_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 1d7205365..fd2e66e38 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -273,7 +273,7 @@ def best_effort_sharding(shape, *, devices=None, mesh=None): return sharding else: # get the existing mesh and find the FSDP axis - fsdp_axis = mesh.axis_names.index(haliax.partitioning.ResourceAxis.DATA) + fsdp_axis = mesh.axis_names.index(hax.partitioning.ResourceAxis.DATA) num_devices = mesh.devices.shape[fsdp_axis] for i in range(len(shape) - 1, -1, -1): @@ -285,7 +285,7 @@ def best_effort_sharding(shape, *, devices=None, mesh=None): return NamedSharding(mesh, PartitionSpec(None)) axis_sharding = [None] * len(shape) - axis_sharding[sharded_axis] = haliax.partitioning.ResourceAxis.DATA + axis_sharding[sharded_axis] = hax.partitioning.ResourceAxis.DATA sharding = NamedSharding(mesh, PartitionSpec(*axis_sharding)) return sharding