Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix "best effort" hf sharding now that we have fancy meshes #622

Merged
merged 4 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"equinox>=0.11.4",
"jaxtyping>=0.2.20",
"tokenizers>=0.15.2",
"transformers>=4.39.3",
"transformers>=4.41.2",
"optax>=0.1.9",
"wandb>=0.16.6,<0.18.0",
# We don't actually directly depend on scipy, but recent JAX had an issue
Expand Down
24 changes: 9 additions & 15 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import jmp
import numpy as np
from draccus import field
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental import multihost_utils
from jax.sharding import Mesh
from jaxtyping import PRNGKeyArray, PyTree
from optax import GradientTransformation
Expand All @@ -56,6 +56,7 @@
from levanter.trainer_state import TrainerState, saveable_training_mask
from levanter.types import ComputeLossFunction, FilterSpec, ModuleComputeLoss
from levanter.utils import cloud_utils, fsspec_utils
from levanter.utils.jax_utils import create_fsdp_mesh
from levanter.utils.tree_utils import inference_mode


Expand Down Expand Up @@ -643,20 +644,13 @@ def initialize(self):

@cached_property
def device_mesh(self) -> Mesh:
is_multislice = hasattr(jax.devices()[0], "slice_index")
if is_multislice:
devices = mesh_utils.create_hybrid_device_mesh(
(self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size),
(self.replica_dcn_axis_size, self.data_dcn_axis_size, 1),
allow_split_physical_axes=True,
)
else:
devices = mesh_utils.create_device_mesh(
(self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size),
allow_split_physical_axes=True,
)

return Mesh(devices, (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL))
return create_fsdp_mesh(
self.replica_ici_axis_size,
self.data_ici_axis_size,
self.model_axis_size,
self.replica_dcn_axis_size,
self.data_dcn_axis_size,
)

@property
def eval_batch_size(self):
Expand Down
80 changes: 66 additions & 14 deletions src/levanter/utils/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import PositionalSharding
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec, PositionalSharding
from jaxtyping import PRNGKeyArray, PyTree

import haliax
from haliax.jax_utils import is_jax_array_like
from haliax.partitioning import ResourceAxis


X = TypeVar("X")
Expand Down Expand Up @@ -239,24 +242,73 @@ def as_arrayish(x):
return jnp.asarray(x)


def best_effort_sharding(shape, devices=None):
def best_effort_sharding(shape, *, devices=None, mesh=None):
if hasattr(shape, "shape"):
shape = shape.shape

if devices is None:
devices = jax.devices()
device_shape = (len(devices),)
# we want to shard an array with shape shape across len(devices)
# each axis in the array has to be divisible by the corresponding axis in device_shape, so
# we iterate from the right, taking the gcd of the shape and the left-most axis of device_shape
for i in range(len(shape) - 1, -1, -1):
shape_i = shape[i]
device_shape_i = device_shape[0]
gcd = np.gcd(shape_i, device_shape_i)
device_shape_i //= gcd
device_shape = (device_shape_i, gcd) + device_shape[1:]
sharding = PositionalSharding(devices).reshape(list(device_shape)).replicate(axis=0, keepdims=True)
return sharding

if mesh is None:
mesh = haliax.partitioning._get_mesh()
if mesh.devices.shape == ():
mesh = None

if mesh is None:
device_shape = (len(devices),)
# we want to shard an array with shape shape across len(devices)
# each axis in the array has to be divisible by the corresponding axis in device_shape, so
# we iterate from the right, taking the gcd of the shape and the left-most axis of device_shape
num_devices = device_shape[0]

for i in range(len(shape) - 1, -1, -1):
shape_i = shape[i]
gcd = np.gcd(shape_i, num_devices)
num_devices //= gcd
device_shape = (num_devices, gcd) + device_shape[1:]
sharding = PositionalSharding(devices).reshape(list(device_shape)).replicate(axis=0, keepdims=True)
return sharding
else:
# get the existing mesh and find the FSDP axis
fsdp_axis = mesh.axis_names.index(haliax.partitioning.ResourceAxis.DATA)
num_devices = mesh.devices.shape[fsdp_axis]

for i in range(len(shape) - 1, -1, -1):
shape_i = shape[i]
if shape_i % num_devices == 0:
sharded_axis = i
break
else:
return NamedSharding(mesh, PartitionSpec(None))

axis_sharding = [None] * len(shape)
axis_sharding[sharded_axis] = haliax.partitioning.ResourceAxis.DATA
sharding = NamedSharding(mesh, PartitionSpec(*axis_sharding))

return sharding


def create_fsdp_mesh(
replica_ici_axis_size: int,
data_ici_axis_size: int,
model_axis_size: int,
replica_dcn_axis_size: int = 1,
data_dcn_axis_size: int = 1,
):
is_multislice = hasattr(jax.devices()[0], "slice_index")
if is_multislice:
devices = mesh_utils.create_hybrid_device_mesh(
(replica_ici_axis_size, data_ici_axis_size, model_axis_size),
(replica_dcn_axis_size, data_dcn_axis_size, 1),
allow_split_physical_axes=True,
)
else:
devices = mesh_utils.create_device_mesh(
(replica_ici_axis_size, data_ici_axis_size, model_axis_size),
allow_split_physical_axes=True,
)

return Mesh(devices, (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL))


def estimated_free_device_memory(device) -> Optional[float]:
Expand Down
40 changes: 38 additions & 2 deletions tests/test_jax_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import jax
import numpy as np
import pytest

from levanter.utils.jax_utils import best_effort_sharding
from levanter.utils.jax_utils import best_effort_sharding, create_fsdp_mesh
from test_utils import skip_if_not_enough_devices


def _assert_can_put_with_sharding(array, sharding):
try:
jax.device_put(array, sharding)
except ValueError:
assert False, f"Could not put array with shape {array.shape} with sharding {sharding}"
# assert False, f"Could not put array with shape {array.shape} with sharding {sharding}"
raise AssertionError(f"Could not put array with shape {array.shape} with sharding {sharding}")


@skip_if_not_enough_devices(8)
Expand Down Expand Up @@ -44,3 +46,37 @@ def test_best_effort_sharding():
array = array.reshape(2, 2, 2)
sharding = best_effort_sharding(array.shape, devices=devices)
_assert_can_put_with_sharding(array, sharding)


@pytest.mark.parametrize("fsdp_size", [1, 2, 4, 8])
def test_best_effort_sharding_with_mesh(fsdp_size):
if fsdp_size > len(jax.devices()):
pytest.skip("Not enough devices")
elif len(jax.devices()) % fsdp_size != 0:
pytest.skip("Number of devices is not a multiple of fsdp_size")

mesh = create_fsdp_mesh(len(jax.devices()) // fsdp_size, fsdp_size, 1)

array = np.arange(8)
sharding = best_effort_sharding(array.shape, mesh=mesh)
_assert_can_put_with_sharding(array, sharding)

array = array.reshape(2, 4)
sharding = best_effort_sharding(array.shape, mesh=mesh)
_assert_can_put_with_sharding(array, sharding)

array = array.reshape(4, 2)
sharding = best_effort_sharding(array.shape, mesh=mesh)
_assert_can_put_with_sharding(array, sharding)

array = array.reshape(8, 1)
sharding = best_effort_sharding(array.shape, mesh=mesh)
_assert_can_put_with_sharding(array, sharding)

array = array.reshape(1, 8)
sharding = best_effort_sharding(array.shape, mesh=mesh)
_assert_can_put_with_sharding(array, sharding)

array = array.reshape(2, 2, 2)
sharding = best_effort_sharding(array.shape, mesh=mesh)
_assert_can_put_with_sharding(array, sharding)
Loading