diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 7e25c88ef..20a11d090 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -102,19 +102,16 @@ async def current_len(self) -> Optional[int]: async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: token_arrays = await self._await_token_cache() # logger.info(f"Time to get token cache: {time.time() - time_in}") - print(f"waiting until len is at least {max(indices) + 1}") len = await self.wait_until_len_at_least(max(indices) + 1) if len is not None and len < max(indices) + 1: raise ValueError("Requested indices beyond the end of the dataset") offsets = np.array(indices) * self.seq_len - print(f"getting offsets {offsets}") with ts.Batch(): out = [] for offset in offsets: out.append(token_arrays.data[offset : offset + self.seq_len].read()) out = await asyncio.gather(*out) - print("done waiting") return out def get_batch_sync(self, indices: Sequence[int]) -> Sequence[T_co]: