fix crash in data loader caused by using stale array #765
Merged
Changes from 4 of 6 commits
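The crash this PR addresses comes from handing the loader its resume step as a JAX array instead of a plain int; presumably the device buffer backing that scalar can be freed before the iterator reads it, which is what makes it "stale". The diff below fixes this by coercing the value at the top of iter_from_step. The following is only a minimal sketch of that coercion, based on my reading of the diff; normalize_start_batch is a hypothetical name, not a function in the project.

# Hypothetical sketch of the coercion the diff adds in iter_from_step.
from typing import Optional

import jax.numpy as jnp


def normalize_start_batch(start_from_batch) -> Optional[int]:
    # Converting to a plain Python int copies the scalar to the host, so the
    # iterator no longer holds a reference to a device buffer that may go stale.
    return int(start_from_batch) if start_from_batch is not None else None


print(normalize_start_batch(jnp.asarray(42)))  # 42, an ordinary int
print(normalize_start_batch(None))             # None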
@@ -2,7 +2,7 @@
import logging
import time
from collections import defaultdict
from typing import Iterable, Iterator, Optional, Tuple, TypeVar
from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, Tuple, TypeVar

import jax
from jax import Array

@@ -20,8 +20,9 @@
from levanter.data.dataset import AsyncDataset
from levanter.data.utils import batched
from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape
from levanter.utils.background_iterable import BackgroundIterable
from levanter.utils.thread_utils import blocking_wait
from levanter.utils.background_iterable import BackgroundIterator
from levanter.utils.jax_utils import local_cpu_mesh
from levanter.utils.thread_utils import AsyncIteratorWrapper, blocking_wait


Ex = TypeVar("Ex")

@@ -98,6 +99,8 @@ def __iter__(self):
        return self.iter_from_step(None)

    def iter_from_step(self, start_from_batch: Optional[int] = None):
        # sometimes we pass in an array for the start_from_batch, so we need to check for that
        start_from_batch = int(start_from_batch) if start_from_batch is not None else None
        return DataLoaderIterator(self, start_from_batch=start_from_batch)

@@ -109,115 +112,132 @@ def __init__(self, data_loader: DataLoader, start_from_batch: Optional[int] = No
        if self.mapping is None:
            self.mapping = hax.partitioning.current_thread_local_mapping()

        # TODO: bring back non-prefetching version
        buffered_batches = self.dl.max_buffered_batches
        self._batches = iter(BackgroundIterable(self._produce_batches, max_capacity=buffered_batches))
        self._batches: Iterator[Ex]
        if buffered_batches == 0:
            self._batches = AsyncIteratorWrapper(self._produce_batches())
        else:
            self._batches = _JaxCpuBackgroundIterator(self._produce_batches, max_capacity=buffered_batches)

    def __next__(self):
        time_start = time.time()
        out = next(self._batches)
        individual_data_batch = next(self._batches)
        data_for_this_batch = {index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)}
        batch = self._batchify_local_data(data_for_this_batch)

        time_end = time.time()
        if (time_end - time_start) > 0.5:
            logger.info(f"Prefetch wasn't fast enough: {time_end - time_start:.3f}")
        return out
        return batch

    async def _produce_batches(self):
        batch_number = self._start_from_batch or 0
        total_ex_loaded = 0
        done = False
        while not done:
            next_batch_numbers = []
            for i in range(self.dl.prefetch_size):
                if self.dl.data_store.is_finite():
                    next_end = (batch_number + 1) * self.dl.batch_size
                    available_len = await self.dl.data_store.wait_until_len_at_least(next_end)
                    if available_len < next_end:
                        done = True
                        break

                next_batch_numbers.append(batch_number)
                batch_number += 1
            target_next_batch_number = batch_number + self.dl.prefetch_size
            max_achievable_batch_number = await self._dataset_get_available_batch_number(target_next_batch_number)
            if max_achievable_batch_number < target_next_batch_number:
                done = True

            next_batch_numbers = list(range(batch_number, min(target_next_batch_number, max_achievable_batch_number)))

            if len(next_batch_numbers) == 0:
                break

            batch_number = next_batch_numbers[-1] + 1

            async for batch in self._retrieve_batches(next_batch_numbers):
                yield batch

            total_ex_loaded += self.dl.batch_size * len(next_batch_numbers)
    async def _dataset_get_available_batch_number(self, target_max_batch_number: int) -> int:
        if self.dl.data_store.is_finite():
            next_end = (target_max_batch_number + 1) * self.dl.batch_size
            available_len = await self.dl.data_store.wait_until_len_at_least(next_end)
            max_achievable_batch_number = available_len // self.dl.batch_size

    async def _retrieve_batches(self, batch_numbers: list[int]):
        with hax.axis_mapping(self.mapping), self.dl.mesh:
            indices_for_this_batch_of_batches: list[int] = []
            for bn in batch_numbers:
                indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1)
                indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices]
                indices_for_this_batch_of_batches.extend(indices_this_batch_this_process)
            return max_achievable_batch_number

        return target_max_batch_number

    async def _retrieve_batches(self, batch_numbers: list[int]):
        with local_cpu_mesh():
            time_start = time.time()
            individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches)
            individual_datums_for_each_batch = await self._do_retrieve_batch_of_batches(batch_numbers)
            # reshape to be per batch
            time_end = time.time()
            logger.debug(f"Time to get {len(batch_numbers)} batches: {time_end - time_start:.3f}")
            time_start = time.time()
            # reshape to be per batch
            individual_datums = list(batched(individual_datums, len(self.dl._local_indices)))

            # below we're gonna get the indices relative to this batch (i.e. 0 to batch_size)
            index_to_datum = [
                {index: datum for index, datum in zip(self.dl._local_indices, individual_data_batch)}
                for individual_data_batch in individual_datums
            ]

            def get_local_batch(bn: int, begin: int, end: int) -> list:
                # TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
                # which will require support from the datastore (i.e. tensorstore)
                device_batch = _stack_tree(self.dl.Batch.name, [index_to_datum[bn][i] for i in range(begin, end)])
                batch_leaves = hax.tree_util.tree_leaves(device_batch)
                return batch_leaves

            def get_local_data_for_leaf(bn, indices: _TensorSliceIndex, leaf_index: int) -> Array:
                batch_slice = indices[0]
                begin, end, stride = batch_slice.indices(self.dl.batch_size)
                if stride != 1:
                    raise ValueError("Stride must be 1")

                leaf_data = (get_local_batch(bn, begin, end))[leaf_index]

                if isinstance(leaf_data, hax.NamedArray):
                    # select out the batch axis
                    batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes)
                    new_indices = list(indices)
                    new_indices[batch_index] = slice(None)
                    return leaf_data.array[tuple(new_indices)]

            for data in individual_datums_for_each_batch:
                yield data

    def _batchify_local_data(self, data_for_this_batch: dict[int, Array]):
        cache: dict[tuple[int, int], list[Array | hax.NamedArray]] = {}

        def get_local_batch(begin: int, end: int) -> list:
            if (begin, end) in cache:
                return cache[(begin, end)]

            # TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
            # which will require support from the datastore (i.e. tensorstore)
            device_batch = _stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)])
            batch_leaves = hax.tree_util.tree_leaves(device_batch)

            cache[(begin, end)] = batch_leaves

            return batch_leaves

        def get_local_data_for_leaf(indices: _TensorSliceIndex, leaf_index: int) -> Array:
            batch_slice = indices[0]
            begin, end, stride = batch_slice.indices(self.dl.batch_size)
            if stride != 1:
                raise ValueError("Stride must be 1")

            leaf_data = get_local_batch(begin, end)[leaf_index]

            if isinstance(leaf_data, hax.NamedArray):
                # select out the batch axis
                batch_index = index_where(lambda ax: ax.name == self.dl.Batch.name, leaf_data.axes)
                new_indices = list(indices)
                new_indices[batch_index] = slice(None)
                return leaf_data.array[tuple(new_indices)]
            else:
                other_indices = indices[1:]
                if all(idx == slice(None) for idx in other_indices):
                    return leaf_data
                else:
                    other_indices = indices[1:]
                    if all(idx == slice(None) for idx in other_indices):
                        return leaf_data
                    else:
                        # TODO: this doesn't work with named axes
                        return leaf_data[(..., *other_indices)]

            for batch_offset, bn in enumerate(batch_numbers):

                def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec):
                    def get_data(indices):
                        return get_local_data_for_leaf(batch_offset, indices, leaf_index)

                    raw_array = jax.make_array_from_callback(
                        to_raw_shape(item_leaf_shape),
                        jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)),
                        get_data,
                    )
                    if isinstance(item_leaf_shape, NamedShapeSpec):
                        return hax.NamedArray(raw_array, item_leaf_shape.shape)
                    else:
                        return raw_array

                gda_leaves = [
                    make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf))
                    for leaf_index, item_leaf in enumerate(self.dl._ex_leaves)
                ]

                gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves)
                yield gda_tree
                    # TODO: this doesn't work with named axes
                    return leaf_data[(..., *other_indices)]

Review comment: Wouldn't anything with named axes go into the
Reply: yeah old comment, sorry

        def make_global_array_for_leaf(leaf_index, item_leaf_shape: ShapeSpec | NamedShapeSpec):
            def get_data(indices):
                return get_local_data_for_leaf(indices, leaf_index)

            raw_array = jax.make_array_from_callback(
                to_raw_shape(item_leaf_shape),
                jax.sharding.NamedSharding(self.dl.mesh, self._pspec_for(item_leaf_shape)),
                get_data,
            )
            if isinstance(item_leaf_shape, NamedShapeSpec):
                return hax.NamedArray(raw_array, item_leaf_shape.shape)
            else:
                return raw_array

        gda_leaves = [
            make_global_array_for_leaf(leaf_index, _batchified_shape(self.dl.Batch, item_leaf))
            for leaf_index, item_leaf in enumerate(self.dl._ex_leaves)
        ]
        gda_tree = jax.tree.unflatten(self.dl._ex_structure, gda_leaves)
        return gda_tree

    async def _do_retrieve_batch_of_batches(self, batch_numbers):
        indices_for_this_batch_of_batches: list[int] = []
        for bn in batch_numbers:
            indices_this_batch = range(bn * self.dl.batch_size, (bn + 1) * self.dl.batch_size, 1)
            indices_this_batch_this_process = [indices_this_batch[i] for i in self.dl._local_indices]
            indices_for_this_batch_of_batches.extend(indices_this_batch_this_process)
        individual_datums = await self.dl.data_store.get_batch(indices_for_this_batch_of_batches)
        individual_datums_for_each_batch = list(batched(individual_datums, len(self.dl._local_indices)))
        return individual_datums_for_each_batch

    def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
        if isinstance(shape_spec, ShapeSpec): # type: ignore

@@ -227,31 +247,24 @@ def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
            return hax.partitioning.pspec_for_axis(shape_spec.shape, self.dl.axis_resources) # type: ignore


def _abstractify(x):
    def _abstractify_array(x):
        if isinstance(x, jax.numpy.ndarray):
            return ShapeSpec(x.shape, x.dtype)
        elif isinstance(x, hax.NamedArray):
            return NamedShapeSpec(x.axes, x.dtype)

        return x

    return hax.tree_util.tree_map(_abstractify_array, x)


def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedShapeSpec:
    if is_named_array(leaf):
        return NamedShapeSpec((Batch,) + leaf.axes, leaf.dtype)
    else:
        return ShapeSpec((Batch.size,) + leaf.shape, leaf.dtype)


    def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec:
        if isinstance(shape_spec, ShapeSpec): # type: ignore
            batch_name = hax.partitioning.physical_axis_name(self.Batch, self.axis_resources)
            return PartitionSpec(batch_name, *((None,) * (len(shape_spec.shape) - 1)))
        else:
            return hax.partitioning.pspec_for_axis(shape_spec.shape, self.axis_resources) # type: ignore
class _JaxCpuBackgroundIterator(BackgroundIterator[Ex]):
    """
    We want the thread to only use the CPU device.
    """

    def __init__(self, producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]], max_capacity: Optional[int]):
        super().__init__(producer_fn, max_capacity)

    def _fill_queue_with_batches(self):
        with local_cpu_mesh():
            super()._fill_queue_with_batches()


@functools.partial(jax.jit, static_argnums=(0,))

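The new _JaxCpuBackgroundIterator above routes the prefetch thread's work through local_cpu_mesh(), Levanter's helper, so batches assembled in the background land in host memory rather than on the accelerator. Below is a rough stand-in for that idea using only plain JAX; it is my illustration, assumes jax.default_device is an acceptable substitute, and does not claim to match what local_cpu_mesh does internally.

# Illustration only: a background producer whose arrays default to the CPU device.
import queue
import threading

import jax
import jax.numpy as jnp


def produce_on_cpu(out: queue.Queue, n: int) -> None:
    cpu = jax.devices("cpu")[0]
    with jax.default_device(cpu):  # arrays created in this thread default to the CPU device
        for i in range(n):
            out.put(jnp.arange(4) + i)
    out.put(None)  # sentinel meaning "no more batches"


buffer: queue.Queue = queue.Queue(maxsize=2)  # bounded, in the spirit of max_buffered_batches
worker = threading.Thread(target=produce_on_cpu, args=(buffer, 3), daemon=True)
worker.start()

while (batch := buffer.get()) is not None:
    print(batch.devices())  # {CpuDevice(id=0)} even when an accelerator is present

The bounded queue gives the same kind of back-pressure a background iterator with a max capacity would: the producer blocks once the buffer is full.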
Review comment: Is there a way to account for the case the comment describes in the MyPy type hint?
Reply: I could make it a jnp.ndarray, but I'd rather we didn't pass it in, and I'm vaguely annoyed mypy isn't raising an error already.
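On the type-hint question, one way to make the signature admit a JAX scalar, so mypy has something to check against, would be to widen the annotation and keep the same int() normalization. This is only a sketch of that option, not what the PR does; the merged change keeps Optional[int] in the signature and simply coerces.

# Sketch only: a widened hint plus the same normalization. The real method returns a
# DataLoaderIterator; this standalone function just shows the hint and the coercion.
from typing import Optional, Union

import jax


def iter_from_step(start_from_batch: Optional[Union[int, jax.Array]] = None) -> Optional[int]:
    # A plain int or a JAX scalar is accepted; anything else should be flagged by mypy.
    return int(start_from_batch) if start_from_batch is not None else None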