Skip to content

Commit

Permalink
ok this maybe fixed it?
Browse files (browse the repository at this point in the history)
  • Loading branch information
dlwh committed Oct 14, 2024
1 parent 20a4568 commit 2747705
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions src/levanter/data/loader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -131,25 +131,30 @@ async def _produce_batches(self):
batch_number = self._start_from_batch or 0
done = False
while not done:
next_batch_numbers = []
for i in range(self.dl.prefetch_size):
if await self._dataset_has_enough_examples_left(batch_number):
done = True
break
target_next_batch_number = batch_number + self.dl.prefetch_size
max_achievable_batch_number = await self._dataset_get_available_batch_number(target_next_batch_number)
if max_achievable_batch_number < target_next_batch_number:
done = True

next_batch_numbers.append(batch_number)
batch_number += 1
next_batch_numbers = list(range(batch_number, min(target_next_batch_number, max_achievable_batch_number)))

if len(next_batch_numbers) == 0:
break

batch_number = next_batch_numbers[-1] + 1

async for batch in self._retrieve_batches(next_batch_numbers):
yield batch

async def _dataset_has_enough_examples_left(self, batch_number):
past_the_end = False
async def _dataset_get_available_batch_number(self, target_max_batch_number: int) -> int:
if self.dl.data_store.is_finite():
next_end = (batch_number + 1) * self.dl.batch_size
next_end = (target_max_batch_number + 1) * self.dl.batch_size
available_len = await self.dl.data_store.wait_until_len_at_least(next_end)
past_the_end = available_len < next_end
return past_the_end
max_achievable_batch_number = available_len // self.dl.batch_size

return max_achievable_batch_number

return target_max_batch_number

async def _retrieve_batches(self, batch_numbers: list[int]):
with local_cpu_mesh():
Expand Down

0 comments on commit 2747705

Please sign in to comment.