diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index fdecfa245..c93043dec 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -2,7 +2,7 @@ import logging import time from collections import defaultdict -from typing import Iterable, Iterator, Optional, Tuple, TypeVar +from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, Tuple, TypeVar import jax from jax import Array @@ -20,7 +20,7 @@ from levanter.data.dataset import AsyncDataset from levanter.data.utils import batched from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape -from levanter.utils.background_iterable import BackgroundIterable +from levanter.utils.background_iterable import BackgroundIterator from levanter.utils.jax_utils import local_cpu_mesh from levanter.utils.thread_utils import AsyncIteratorWrapper, blocking_wait @@ -113,10 +113,11 @@ def __init__(self, data_loader: DataLoader, start_from_batch: Optional[int] = No self.mapping = hax.partitioning.current_thread_local_mapping() buffered_batches = self.dl.max_buffered_batches + self._batches: Iterator[Ex] if buffered_batches == 0: self._batches = AsyncIteratorWrapper(self._produce_batches()) else: - self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches)) + self._batches = _JaxCpuBackgroundIterator(self._produce_batches, max_capacity=buffered_batches) def __next__(self): time_start = time.time() @@ -246,18 +247,6 @@ def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec: return hax.partitioning.pspec_for_axis(shape_spec.shape, self.dl.axis_resources) # type: ignore -def _abstractify(x): - def _abstractify_array(x): - if isinstance(x, jax.numpy.ndarray): - return ShapeSpec(x.shape, x.dtype) - elif isinstance(x, hax.NamedArray): - return NamedShapeSpec(x.axes, x.dtype) - - return x - - return hax.tree_util.tree_map(_abstractify_array, x) - - def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedShapeSpec: if is_named_array(leaf): return NamedShapeSpec((Batch,) + leaf.axes, leaf.dtype) @@ -265,6 +254,19 @@ def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedS return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype) +class _JaxCpuBackgroundIterator(BackgroundIterator[Ex]): + """ + We want the thread to only use the CPU device. + """ + + def __init__(self, producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]], max_capacity: Optional[int]): + super().__init__(producer_fn, max_capacity) + + def _fill_queue_with_batches(self): + with local_cpu_mesh(): + super()._fill_queue_with_batches() + + @functools.partial(jax.jit, static_argnums=(0,)) def _stack_tree(batch_name, individual_datums): def _stack_leaves_unchecked(*leaves):