Skip to content

Commit

Permalink
try this
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Oct 14, 2024
1 parent fef836a commit 7e278b5
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions src/levanter/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import time
from collections import defaultdict
from typing import Iterable, Iterator, Optional, Tuple, TypeVar
from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, Tuple, TypeVar

import jax
from jax import Array
Expand All @@ -20,7 +20,7 @@
from levanter.data.dataset import AsyncDataset
from levanter.data.utils import batched
from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape
from levanter.utils.background_iterable import BackgroundIterable
from levanter.utils.background_iterable import BackgroundIterator
from levanter.utils.jax_utils import local_cpu_mesh
from levanter.utils.thread_utils import AsyncIteratorWrapper, blocking_wait

Expand Down Expand Up @@ -113,10 +113,11 @@ def __init__(self, data_loader: DataLoader, start_from_batch: Optional[int] = No
self.mapping = hax.partitioning.current_thread_local_mapping()

buffered_batches = self.dl.max_buffered_batches
self._batches: Iterator[Ex]
if buffered_batches == 0:
self._batches = AsyncIteratorWrapper(self._produce_batches())
else:
self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches))
self._batches = _JaxCpuBackgroundIterator(self._produce_batches, max_capacity=buffered_batches)

def __next__(self):
time_start = time.time()
Expand Down Expand Up @@ -246,25 +247,26 @@ def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
return hax.partitioning.pspec_for_axis(shape_spec.shape, self.dl.axis_resources) # type: ignore


def _abstractify(x):
    """Replace every array leaf of the pytree ``x`` with its abstract shape spec.

    Raw jax arrays become :class:`ShapeSpec` (positional shape + dtype) and
    haliax ``NamedArray``s become :class:`NamedShapeSpec` (named axes + dtype);
    all other leaves are passed through unchanged.

    NOTE(review): reconstructed from a flattened diff — the paste collapsed the
    old one-level function and the new tree_map version; verify against the
    post-commit file.
    """

    def _abstractify_array(x):
        # Plain jax array -> positional shape spec.
        if isinstance(x, jax.numpy.ndarray):
            return ShapeSpec(x.shape, x.dtype)
        # haliax named array -> named shape spec, preserving axis metadata.
        elif isinstance(x, hax.NamedArray):
            return NamedShapeSpec(x.axes, x.dtype)

        # Non-array leaves (scalars, None, ...) are left untouched.
        return x

    # Apply the per-leaf conversion across the whole pytree.
    return hax.tree_util.tree_map(_abstractify_array, x)


def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedShapeSpec:
    """Return *leaf*'s shape spec with the batch axis prepended.

    A named-array leaf keeps its axis metadata and gains the ``Batch`` axis
    object itself; a plain array leaf gains ``Batch.size`` as a leading
    positional dimension.
    """
    # Named arrays carry axis objects, so prepend the Batch axis directly.
    if is_named_array(leaf):
        return NamedShapeSpec((Batch,) + leaf.axes, leaf.dtype)

    # Plain arrays only have positional dims; prepend the batch *size*.
    return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype)


class _JaxCpuBackgroundIterator(BackgroundIterator[Ex]):
    """
    A ``BackgroundIterator`` whose producer work runs under a CPU-only device mesh.

    The queue-filling hook is wrapped in ``local_cpu_mesh()`` so that any jax
    operations performed while producing batches in the background thread stay
    on the host CPU (presumably to avoid touching accelerator devices from
    that thread — confirm against ``local_cpu_mesh``'s contract).
    """

    def __init__(self, producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]], max_capacity: Optional[int]):
        # Pass-through to the base class; producer_fn may yield either a sync
        # or an async iterator of batches, max_capacity bounds the queue.
        super().__init__(producer_fn, max_capacity)

    def _fill_queue_with_batches(self):
        # Run the base class's queue-filling loop with the CPU mesh active.
        with local_cpu_mesh():
            super()._fill_queue_with_batches()


@functools.partial(jax.jit, static_argnums=(0,))
def _stack_tree(batch_name, individual_datums):
def _stack_leaves_unchecked(*leaves):
Expand Down

0 comments on commit 7e278b5

Please sign in to comment.