Fix "best effort" hf sharding now that we have fancy meshes (#622)
dlwh authored Jun 11, 2024
1 parent d0a8f01 commit 7bdd375
Showing 4 changed files with 114 additions and 32 deletions.
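For orientation, here is a minimal sketch of how the reworked helpers fit together after this commit (a hypothetical single-process setup; the array shape is illustrative and assumed divisible by the device count):

import jax
import numpy as np

from levanter.utils.jax_utils import best_effort_sharding, create_fsdp_mesh

# Build a (replica, data, model) mesh over all local devices, placing
# every device on the FSDP/data axis.
mesh = create_fsdp_mesh(1, len(jax.devices()), 1)

# With a non-trivial mesh in hand, best_effort_sharding now returns a
# NamedSharding that shards the right-most evenly divisible axis along
# the mesh's data axis, rather than a PositionalSharding that ignores
# the mesh entirely.
sharding = best_effort_sharding((1024, 512), mesh=mesh)
arr = jax.device_put(np.zeros((1024, 512)), sharding)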
pyproject.toml: 2 changes (1 addition & 1 deletion)
@@ -30,7 +30,7 @@ dependencies = [
     "equinox>=0.11.4",
     "jaxtyping>=0.2.20",
     "tokenizers>=0.15.2",
-    "transformers>=4.39.3",
+    "transformers>=4.41.2",
     "optax>=0.1.9",
     "wandb>=0.16.6,<0.18.0",
     # We don't actually directly depend on scipy, but recent JAX had an issue
src/levanter/trainer.py: 24 changes (9 additions & 15 deletions)
@@ -30,7 +30,7 @@
 import jmp
 import numpy as np
 from draccus import field
-from jax.experimental import mesh_utils, multihost_utils
+from jax.experimental import multihost_utils
 from jax.sharding import Mesh
 from jaxtyping import PRNGKeyArray, PyTree
 from optax import GradientTransformation
@@ -56,6 +56,7 @@
 from levanter.trainer_state import TrainerState, saveable_training_mask
 from levanter.types import ComputeLossFunction, FilterSpec, ModuleComputeLoss
 from levanter.utils import cloud_utils, fsspec_utils
+from levanter.utils.jax_utils import create_fsdp_mesh
 from levanter.utils.tree_utils import inference_mode


@@ -643,20 +644,13 @@ def initialize(self):
 
     @cached_property
     def device_mesh(self) -> Mesh:
-        is_multislice = hasattr(jax.devices()[0], "slice_index")
-        if is_multislice:
-            devices = mesh_utils.create_hybrid_device_mesh(
-                (self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size),
-                (self.replica_dcn_axis_size, self.data_dcn_axis_size, 1),
-                allow_split_physical_axes=True,
-            )
-        else:
-            devices = mesh_utils.create_device_mesh(
-                (self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size),
-                allow_split_physical_axes=True,
-            )
-
-        return Mesh(devices, (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL))
+        return create_fsdp_mesh(
+            self.replica_ici_axis_size,
+            self.data_ici_axis_size,
+            self.model_axis_size,
+            self.replica_dcn_axis_size,
+            self.data_dcn_axis_size,
+        )
 
     @property
     def eval_batch_size(self):
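As a concrete illustration of the delegation (hypothetical axis sizes; create_fsdp_mesh requires that the product of the sizes match the device count): with 8 local devices, a 1-way replica axis, a 4-way FSDP axis, and a 2-way model axis produce a (1, 4, 2) mesh over the axes (replica, data, model).

from levanter.utils.jax_utils import create_fsdp_mesh

# 1 replica group x 4-way data/FSDP x 2-way model parallelism;
# assumes 1 * 4 * 2 == len(jax.devices()).
mesh = create_fsdp_mesh(1, 4, 2)
print(mesh.shape)  # axis sizes: {'replica': 1, 'data': 4, 'model': 2}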
src/levanter/utils/jax_utils.py: 80 changes (66 additions & 14 deletions)
@@ -8,10 +8,13 @@
 import jax
 import numpy as np
 from jax import numpy as jnp
-from jax.sharding import PositionalSharding
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, NamedSharding, PartitionSpec, PositionalSharding
 from jaxtyping import PRNGKeyArray, PyTree
 
+import haliax
 from haliax.jax_utils import is_jax_array_like
+from haliax.partitioning import ResourceAxis
 
 
 X = TypeVar("X")
@@ -239,24 +242,73 @@ def as_arrayish(x):
         return jnp.asarray(x)
 
 
-def best_effort_sharding(shape, devices=None):
+def best_effort_sharding(shape, *, devices=None, mesh=None):
     if hasattr(shape, "shape"):
         shape = shape.shape
 
     if devices is None:
         devices = jax.devices()
-    device_shape = (len(devices),)
-    # we want to shard an array with shape shape across len(devices)
-    # each axis in the array has to be divisible by the corresponding axis in device_shape, so
-    # we iterate from the right, taking the gcd of the shape and the left-most axis of device_shape
-    for i in range(len(shape) - 1, -1, -1):
-        shape_i = shape[i]
-        device_shape_i = device_shape[0]
-        gcd = np.gcd(shape_i, device_shape_i)
-        device_shape_i //= gcd
-        device_shape = (device_shape_i, gcd) + device_shape[1:]
-    sharding = PositionalSharding(devices).reshape(list(device_shape)).replicate(axis=0, keepdims=True)
-    return sharding
+
+    if mesh is None:
+        mesh = haliax.partitioning._get_mesh()
+        if mesh.devices.shape == ():
+            mesh = None
+
+    if mesh is None:
+        device_shape = (len(devices),)
+        # we want to shard an array with shape shape across len(devices)
+        # each axis in the array has to be divisible by the corresponding axis in device_shape, so
+        # we iterate from the right, taking the gcd of the shape and the left-most axis of device_shape
+        num_devices = device_shape[0]
+
+        for i in range(len(shape) - 1, -1, -1):
+            shape_i = shape[i]
+            gcd = np.gcd(shape_i, num_devices)
+            num_devices //= gcd
+            device_shape = (num_devices, gcd) + device_shape[1:]
+        sharding = PositionalSharding(devices).reshape(list(device_shape)).replicate(axis=0, keepdims=True)
+        return sharding
+    else:
+        # get the existing mesh and find the FSDP axis
+        fsdp_axis = mesh.axis_names.index(haliax.partitioning.ResourceAxis.DATA)
+        num_devices = mesh.devices.shape[fsdp_axis]
+
+        for i in range(len(shape) - 1, -1, -1):
+            shape_i = shape[i]
+            if shape_i % num_devices == 0:
+                sharded_axis = i
+                break
+        else:
+            return NamedSharding(mesh, PartitionSpec(None))
+
+        axis_sharding = [None] * len(shape)
+        axis_sharding[sharded_axis] = haliax.partitioning.ResourceAxis.DATA
+        sharding = NamedSharding(mesh, PartitionSpec(*axis_sharding))
+
+        return sharding
+
+
+def create_fsdp_mesh(
+    replica_ici_axis_size: int,
+    data_ici_axis_size: int,
+    model_axis_size: int,
+    replica_dcn_axis_size: int = 1,
+    data_dcn_axis_size: int = 1,
+):
+    is_multislice = hasattr(jax.devices()[0], "slice_index")
+    if is_multislice:
+        devices = mesh_utils.create_hybrid_device_mesh(
+            (replica_ici_axis_size, data_ici_axis_size, model_axis_size),
+            (replica_dcn_axis_size, data_dcn_axis_size, 1),
+            allow_split_physical_axes=True,
+        )
+    else:
+        devices = mesh_utils.create_device_mesh(
+            (replica_ici_axis_size, data_ici_axis_size, model_axis_size),
+            allow_split_physical_axes=True,
+        )
+
+    return Mesh(devices, (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL))
 
 
 def estimated_free_device_memory(device) -> Optional[float]:
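To make the new mesh-aware branch concrete, a small sketch under stated assumptions (exactly 4 devices, so create_fsdp_mesh(1, 4, 1) puts all of them on the data axis; haliax's ResourceAxis.DATA is the string "data"; shapes are illustrative). The no-mesh fallback's gcd walk is unchanged in substance: e.g. shape (4, 6) on 8 devices factors the device axis into (1, 4, 2), sharding the array 4-way on axis 0 and 2-way on axis 1.

from jax.sharding import PartitionSpec

from levanter.utils.jax_utils import best_effort_sharding, create_fsdp_mesh

mesh = create_fsdp_mesh(1, 4, 1)  # assumes len(jax.devices()) == 4

# Scanning right to left: axis 1 is rejected (6 % 4 != 0), axis 0 is
# accepted (4 % 4 == 0), so axis 0 is sharded along the mesh's data axis.
s = best_effort_sharding((4, 6), mesh=mesh)
assert s.spec == PartitionSpec("data", None)

# No axis is divisible by 4, so the array comes back fully replicated.
s = best_effort_sharding((3, 5), mesh=mesh)
assert s.spec == PartitionSpec(None)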
tests/test_jax_utils.py: 40 changes (38 additions & 2 deletions)
@@ -1,15 +1,17 @@
 import jax
 import numpy as np
+import pytest
 
-from levanter.utils.jax_utils import best_effort_sharding
+from levanter.utils.jax_utils import best_effort_sharding, create_fsdp_mesh
 from test_utils import skip_if_not_enough_devices
 
 
 def _assert_can_put_with_sharding(array, sharding):
     try:
         jax.device_put(array, sharding)
     except ValueError:
-        assert False, f"Could not put array with shape {array.shape} with sharding {sharding}"
+        # assert False, f"Could not put array with shape {array.shape} with sharding {sharding}"
+        raise AssertionError(f"Could not put array with shape {array.shape} with sharding {sharding}")
 
 
 @skip_if_not_enough_devices(8)
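Presumably the switch from assert False to an explicit raise is because a bare assert disappears under python -O (and linters such as flake8-bugbear flag assert False for exactly that reason), so the helper now fails reliably.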
@@ -44,3 +46,37 @@ def test_best_effort_sharding():
     array = array.reshape(2, 2, 2)
     sharding = best_effort_sharding(array.shape, devices=devices)
     _assert_can_put_with_sharding(array, sharding)
+
+
+@pytest.mark.parametrize("fsdp_size", [1, 2, 4, 8])
+def test_best_effort_sharding_with_mesh(fsdp_size):
+    if fsdp_size > len(jax.devices()):
+        pytest.skip("Not enough devices")
+    elif len(jax.devices()) % fsdp_size != 0:
+        pytest.skip("Number of devices is not a multiple of fsdp_size")
+
+    mesh = create_fsdp_mesh(len(jax.devices()) // fsdp_size, fsdp_size, 1)
+
+    array = np.arange(8)
+    sharding = best_effort_sharding(array.shape, mesh=mesh)
+    _assert_can_put_with_sharding(array, sharding)
+
+    array = array.reshape(2, 4)
+    sharding = best_effort_sharding(array.shape, mesh=mesh)
+    _assert_can_put_with_sharding(array, sharding)
+
+    array = array.reshape(4, 2)
+    sharding = best_effort_sharding(array.shape, mesh=mesh)
+    _assert_can_put_with_sharding(array, sharding)
+
+    array = array.reshape(8, 1)
+    sharding = best_effort_sharding(array.shape, mesh=mesh)
+    _assert_can_put_with_sharding(array, sharding)
+
+    array = array.reshape(1, 8)
+    sharding = best_effort_sharding(array.shape, mesh=mesh)
+    _assert_can_put_with_sharding(array, sharding)
+
+    array = array.reshape(2, 2, 2)
+    sharding = best_effort_sharding(array.shape, mesh=mesh)
+    _assert_can_put_with_sharding(array, sharding)
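These parametrized cases need several devices; one common way to run them on a CPU-only machine (a sketch, not part of this commit) is to ask XLA for virtual host devices before jax is first imported:

import os

# Must be set before the first jax import, e.g. in conftest.py.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

assert len(jax.devices()) == 8  # eight virtual CPU devices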
