Partially-sharded-data-parallel (#588)
blahBlahhhJ authored May 28, 2024
1 parent 2bb1252 commit ed3c6f1
Showing 7 changed files with 163 additions and 66 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
# "haliax>=1.3,<2.0",
# Haliax changes in step with levanter, so we'll just use the git version except for releases.
# "haliax @ git+https://github.com/stanford-crfm/haliax.git@main",
"haliax>=1.4.dev291",
"haliax>=1.4.dev296",
"equinox>=0.11.4",
"jaxtyping>=0.2.20",
"tokenizers>=0.15.2",
25 changes: 20 additions & 5 deletions src/levanter/data/loader.py
@@ -16,9 +16,9 @@
from haliax.partitioning import ResourceMapping
from haliax.util import is_named_array

import levanter.mesh
from levanter.data import Dataset
from levanter.data.dataset import ShardableDataset
from levanter.mesh import local_devices_mapping, process_mesh_mapping
from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape
from levanter.utils.background_iterable import BackgroundIterable
from levanter.utils.py_utils import non_caching_cycle
@@ -172,8 +172,10 @@ def __init__(
self.mesh = mesh
self.Batch = Batch

process_data_pos = override_process_data_pos or levanter.mesh.process_mesh_position(mesh)[0]
num_data_process_groups = override_process_data_groups or levanter.mesh.process_mesh_size(mesh)[0]
process_mesh_map = process_mesh_mapping(self.mesh)
local_devices_map = local_devices_mapping(self.mesh)
process_data_pos = override_process_data_pos or process_mesh_map[jax.process_index()]
num_data_process_groups = override_process_data_groups or max(process_mesh_map.values()) + 1

if not override_process_data_groups:
assert num_data_process_groups <= jax.process_count()
@@ -182,20 +184,33 @@
self.num_data_process_groups = num_data_process_groups
assert self.Batch.size % num_data_process_groups == 0

self.process_mesh_map = process_mesh_map
self.local_devices_map = local_devices_map
self.per_device_batch_size = self.batch_size // self.mesh.devices.shape[0] // self.mesh.devices.shape[1]

self.item_dataset = local_dataset.shard(process_data_pos, num_data_process_groups)
super().__init__(max_capacity, axis_resources)

def _produce_batches(self) -> Iterator[PyTree]:
one_item_generator = non_caching_cycle(self.item_dataset)
batched = _batched(one_item_generator, self.local_batch_size)

def batch_callback(global_begin, _):
# global_begin uniquely identifies the requesting device's DP/FSDP position:
# global_begin == DP_index * per_device_batch_size
device_pos = global_begin // self.per_device_batch_size

begin = self.local_devices_map[device_pos] * self.per_device_batch_size
end = begin + self.per_device_batch_size

return local_batch[begin:end]

while True:
batch_offset = self.process_data_pos * self.local_batch_size
local_batch: List[PyTree] = next(batched)

batch = self._construct_global_array_for_tree(
item_exemplar=local_batch[0],
get_batch_items=lambda begin, end: local_batch[(begin - batch_offset) : (end - batch_offset)],
get_batch_items=batch_callback,
)

yield batch
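A minimal standalone sketch of the indexing batch_callback performs (the mesh shape, batch size, and local_devices_map values are assumed for illustration, and local_batch is passed in explicitly rather than captured from the enclosing scope):

per_device_batch_size = 2            # e.g. batch_size 16 on a 4 (DP/FSDP) x 2 (TP) device grid
local_devices_map = {2: 0, 3: 1}     # global DP/FSDP group index -> local index on this process

def batch_callback(global_begin, local_batch):
    device_pos = global_begin // per_device_batch_size           # which DP/FSDP group is asking
    begin = local_devices_map[device_pos] * per_device_batch_size
    return local_batch[begin : begin + per_device_batch_size]

local_batch = ["ex0", "ex1", "ex2", "ex3"]                # examples this process loaded
assert batch_callback(4, local_batch) == ["ex0", "ex1"]   # global group 2 -> local rows 0-1
assert batch_callback(6, local_batch) == ["ex2", "ex3"]   # global group 3 -> local rows 2-3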
56 changes: 39 additions & 17 deletions src/levanter/mesh.py
@@ -8,29 +8,51 @@
def local_device_grid_positions(mesh, process_index: Optional[int] = None) -> tuple[np.ndarray, np.ndarray]:
"""Returns a tuple of nd arrays, one for each axis, indicating the position of each device on the grid.
Analogous to what np.where would return."""
pi = process_index or jax.process_index()
# our device indices are [process_index * num_devices_per_node, (process_index + 1) * num_devices_per_node)
# we could be clever here and do math to figure out where we are in the grid, but it's simpler and less
# fragile to just search the grid for our devices
my_device_pos = np.vectorize(lambda dev: dev.process_index == pi)(mesh.devices)
if process_index is None:
process_index = jax.process_index()

my_device_pos = np.vectorize(lambda dev: dev.process_index == process_index)(mesh.devices)
return my_device_pos.nonzero()
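A small illustration of the return value (device layout assumed: a 2x2 mesh whose first row belongs to process 0), mirroring the np.where analogy in the docstring:

import numpy as np

belongs_to_process_0 = np.array([[True, True],
                                 [False, False]])
rows, cols = belongs_to_process_0.nonzero()
# rows == array([0, 0]), cols == array([0, 1]): process 0 sits at grid positions (0, 0) and (0, 1)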


def process_mesh_position(mesh, process_index: Optional[int] = None) -> tuple[int, ...]:
def local_devices_mapping(mesh: Mesh, process_index: Optional[int] = None) -> dict[int, int]:
"""
If we envision each process as a subgrid of the mesh for its devices, this is the position of the process
in the coarsened process-level mesh
Handles the case where different devices in the same process share the same data under TP.
Returns a mapping from each local device's DP/FSDP group index in the global mesh to a local index.
"""
upper_left_position = np.array([np.min(axis) for axis in local_device_grid_positions(mesh, process_index)])
local_mesh_size = mesh.local_mesh.devices.shape
pos = upper_left_position // local_mesh_size
return pos
local_device_pos = local_device_grid_positions(mesh, process_index)[:2] # first 2 axes are DP axes.
result = {}
uid = 0
for local_device_index in range(len(local_device_pos[0])):
key = local_device_pos[0][local_device_index] * mesh.devices.shape[1] + local_device_pos[1][local_device_index]
if key not in result:
# when two devices map to the same key (they differ only in TP index), they get the same data
result[key] = uid
uid += 1
return result
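A pure-Python sketch of the loop above for one assumed layout: a (replica, data, model) mesh of shape (1, 2, 2) with all four devices local to this process, so devices that differ only in TP index share a key and therefore a uid:

local_device_pos = ([0, 0, 0, 0], [0, 0, 1, 1])   # positions along the first two (DP) axes
data_axis_size = 2                                # mesh.devices.shape[1]

result, uid = {}, 0
for i in range(len(local_device_pos[0])):
    key = local_device_pos[0][i] * data_axis_size + local_device_pos[1][i]
    if key not in result:                         # same key: same data, different TP rank
        result[key] = uid
        uid += 1

assert result == {0: 0, 1: 1}                     # two DP/FSDP groups on this process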


def process_mesh_size(mesh: Mesh) -> tuple[int, ...]:
def process_mesh_mapping(mesh) -> dict[int, int]:
"""
If we envision each process as a subgrid of the mesh for its devices, then there is a process grid that
is a coarsened version of the mesh. This is the size of the process grid.
Handles the case where different processes share the same data under TP.
If we envision each process as a subgrid of the mesh for its devices, this returns, for each process index,
the index of its DP/FSDP data group in the coarsened process-level mesh.
"""
local_mesh_size = mesh.local_mesh.devices.shape
return tuple(mesh.devices.shape[i] // local_mesh_size[i] for i in range(len(local_mesh_size)))
devices = mesh.devices
result = {}
uid = 0
leftmost2uid = {}
# basic logic: process index -> upper-left device -> TP index 0 device -> process index -> uid
for process_index in range(jax.process_count()):
tmp = [np.min(axis) for axis in local_device_grid_positions(mesh, process_index)]
tmp[-1] = 0 # we want the device with TP group index 0 in the same DP/FSDP group
upper_left_position = tuple(tmp) # in order to index into devices
upper_left_process = devices[upper_left_position].process_index
# assign uid to each process that has a device with TP group index 0
if upper_left_process not in leftmost2uid:
leftmost2uid[upper_left_process] = uid
uid += 1
this_uid = leftmost2uid[upper_left_process]
result[process_index] = this_uid

return result
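A usage sketch (assuming a multi-host Mesh named mesh has already been built): this mapping is what the loader above uses to decide which data shard each host reads and how many distinct shards exist, so hosts that differ only in their TP position read the same shard instead of duplicating it.

import jax
from levanter.mesh import process_mesh_mapping

process_mesh_map = process_mesh_mapping(mesh)      # {process_index: DP/FSDP group uid}
my_group = process_mesh_map[jax.process_index()]   # which shard this host should load
num_groups = max(process_mesh_map.values()) + 1    # number of distinct data shards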
19 changes: 15 additions & 4 deletions src/levanter/models/attention.py
@@ -797,10 +797,21 @@ def _tpu_splash_attention(
raise ValueError(f"Embedding axes must be the same for q, k, and v: {q_class['D']} != {k_class['D']}")

def _physical_axis_for_binning(d):
b_out = tuple(ax for ax in pspec_for_axis(d["B"]) if ax is not None) or None
h_out = tuple(ax for ax in pspec_for_axis(d["H"]) if ax is not None) or None
s_out = tuple(ax for ax in pspec_for_axis(d["S"]) if ax is not None) or None
d_out = tuple(ax for ax in pspec_for_axis(d["D"]) if ax is not None) or None
def flatten(axes):
if axes is None:
return axes
result = []
for ax in axes:
if isinstance(ax, tuple):
result += list(ax)
else:
result.append(ax)
return tuple(result)

b_out = flatten(tuple(ax for ax in pspec_for_axis(d["B"]) if ax is not None) or None)
h_out = flatten(tuple(ax for ax in pspec_for_axis(d["H"]) if ax is not None) or None)
s_out = flatten(tuple(ax for ax in pspec_for_axis(d["S"]) if ax is not None) or None)
d_out = flatten(tuple(ax for ax in pspec_for_axis(d["D"]) if ax is not None) or None)

return PartitionSpec(b_out, h_out, s_out, d_out)
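A self-contained copy of the flatten helper with a quick check of why it is needed (the physical axis names "replica", "data", "model" are illustrative): once the batch axis maps to a composite ("replica", "data") physical spec, the raw per-dimension specs can contain nested tuples, which must be flattened before building the PartitionSpec.

def flatten(axes):
    if axes is None:
        return None
    result = []
    for ax in axes:
        result += list(ax) if isinstance(ax, tuple) else [ax]
    return tuple(result)

assert flatten((("replica", "data"),)) == ("replica", "data")
assert flatten(("data", "model")) == ("data", "model")
assert flatten(None) is None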

68 changes: 57 additions & 11 deletions src/levanter/trainer.py
@@ -30,7 +30,7 @@
import jmp
import numpy as np
from draccus import field
from jax.experimental import multihost_utils
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh
from jaxtyping import PRNGKeyArray, PyTree
from optax import GradientTransformation
@@ -549,12 +549,19 @@ class TrainerConfig:
fsdp_axis: Optional[Union[str, List[str]]] = "embed" # Axis/Axes to use for FSDP
tensor_parallel_axes: Optional[List[str]] = None # Axes, if any, to use for tensor parallelism

# TODO: in theory we can support tuples of physical axis names, but I don't think anyone actually uses that.
axis_resources: Mapping[str, str] = field(default_factory=dict)
axis_resources: Mapping[str, Union[Tuple[str], str]] = field(default_factory=dict)
"""mapping from logical axis to physical axis. batch_axis, fsdp_axis, and tensor_parallel_axes are preferred"""
parameter_axis_resources: Mapping[str, str] = field(default_factory=dict) # overrides axis_mapping for parameter
parameter_axis_resources: Mapping[str, Union[Tuple[str], str]] = field(
default_factory=dict
) # overrides axis_mapping for parameter
"""logical->physical mapping for parameter/optimizer sharding. fsdp_axis and tensor_parallel_axes are preferred"""
model_axis_size: int = 1 # how many devices to shard each model over. Data axis is the other axis

"""Interchip Interconnect (ICI) & Data Center Networking (DCN) shardings https://cloud.google.com/tpu/docs/multislice-introduction"""
replica_ici_axis_size: int = 1
model_axis_size: int = 1
"""how many devices within each slice to use for DP. With TP fixed at 1, the rest of the devices in each slice are used for FSDP."""
replica_dcn_axis_size: int = 1
"""how many slices in the multislice scheme to use for DP and TP. The rest of the devices are used for FSDP."""

# Config related to batch sizes
train_batch_size: int = 512
@@ -636,19 +643,58 @@ def initialize(self):

@cached_property
def device_mesh(self) -> Mesh:
devices = jax.devices()
devices = np.array(devices).reshape(self.data_axis_size, self.model_axis_size)
return Mesh(devices, (ResourceAxis.DATA, ResourceAxis.MODEL))
is_multislice = hasattr(jax.devices()[0], "slice_index")
if is_multislice:
devices = mesh_utils.create_hybrid_device_mesh(
(self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size),
(self.replica_dcn_axis_size, self.data_dcn_axis_size, 1),
allow_split_physical_axes=True,
)
else:
devices = mesh_utils.create_device_mesh(
(self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size),
allow_split_physical_axes=True,
)

return Mesh(devices, (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL))
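A single-slice sketch of the non-multislice branch (assumes exactly 8 visible devices; axis names are written as plain strings for illustration): one replica, an FSDP axis of 4, and TP of 2 give a (1, 4, 2) mesh with the new replica axis in front.

from jax.experimental import mesh_utils
from jax.sharding import Mesh

devices = mesh_utils.create_device_mesh((1, 4, 2), allow_split_physical_axes=True)
mesh = Mesh(devices, ("replica", "data", "model"))
print(mesh.devices.shape)   # (1, 4, 2)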

@property
def eval_batch_size(self):
return self.per_device_eval_parallelism * self.data_axis_size

@cached_property
def num_slices(self):
"""number of slices (1 when not running multislice)"""
return max(getattr(device, "slice_index", 0) for device in jax.devices()) + 1

@property
def num_devices_per_slice(self):
"""number of devices within a slice"""
return jax.device_count() // self.num_slices

@property
def data_ici_axis_size(self):
"""size of the FSDP axis within slices"""
assert self.num_devices_per_slice % (self.replica_ici_axis_size * self.model_axis_size) == 0
return self.num_devices_per_slice // (self.replica_ici_axis_size * self.model_axis_size)

@property
def data_dcn_axis_size(self):
"""size of the FSDP axis across slices"""
assert self.num_slices % self.replica_dcn_axis_size == 0
return self.num_slices // self.replica_dcn_axis_size

@property
def data_axis_size(self):
"""size of the data parallel/batch parallel axis."""
assert jax.device_count() % self.model_axis_size == 0
return jax.device_count() // self.model_axis_size
return (
self.data_dcn_axis_size * self.data_ici_axis_size * self.replica_dcn_axis_size * self.replica_ici_axis_size
)

@property
def replica_axis_size(self):
"""size of the replica (pure data parallel) axis."""
return self.replica_dcn_axis_size * self.replica_ici_axis_size
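A worked example of the size bookkeeping above (all numbers assumed): 2 slices of 16 devices each, TP of 2, 2 pure-DP replicas inside each slice, and no extra replication across slices.

num_slices, devices_per_slice = 2, 16
replica_ici_axis_size, model_axis_size, replica_dcn_axis_size = 2, 2, 1

data_ici_axis_size = devices_per_slice // (replica_ici_axis_size * model_axis_size)  # 4 (FSDP within a slice)
data_dcn_axis_size = num_slices // replica_dcn_axis_size                             # 2 (FSDP across slices)
data_axis_size = (data_dcn_axis_size * data_ici_axis_size
                  * replica_dcn_axis_size * replica_ici_axis_size)                   # 16
replica_axis_size = replica_dcn_axis_size * replica_ici_axis_size                    # 2
assert data_axis_size * model_axis_size == num_slices * devices_per_slice            # all 32 devices used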

@cached_property
def compute_axis_mapping(self) -> ResourceMapping:
@@ -662,7 +708,7 @@ def compute_axis_mapping(self) -> ResourceMapping:
axes_to_return[axis] = ResourceAxis.MODEL

if self.batch_axis is not None:
axes_to_return[self.batch_axis] = ResourceAxis.DATA
axes_to_return[self.batch_axis] = (ResourceAxis.REPLICA, ResourceAxis.DATA) # type: ignore

return axes_to_return

4 changes: 2 additions & 2 deletions tests/test_config.py
@@ -39,15 +39,15 @@ def test_new_style_axis_mapping():

assert config.tensor_parallel_axes == ["a1", "a2"]
assert config.compute_axis_mapping == {
"batch": ResourceAxis.DATA,
"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA),
"a1": ResourceAxis.MODEL,
"a2": ResourceAxis.MODEL,
}
assert config.parameter_axis_mapping == {
"embed": ResourceAxis.DATA,
"a1": ResourceAxis.MODEL,
"a2": ResourceAxis.MODEL,
"batch": ResourceAxis.DATA,
"batch": (ResourceAxis.REPLICA, ResourceAxis.DATA),
}

