diff --git a/config/llama2_small_fast_mix.yaml b/config/llama2_small_fast_mix.yaml
index aabd17fae..29b0a8a52 100644
--- a/config/llama2_small_fast_mix.yaml
+++ b/config/llama2_small_fast_mix.yaml
@@ -1,8 +1,7 @@
 data:
   tokenizer: "meta-llama/Llama-2-7b-hf"
   cache_dir: "gs://levanter-data/new-tokenized/pile_mix/"
-  shuffle:
-    era_length: 10000
+  shuffle: 10000
   configs:
     arxiv:
       train_urls:
diff --git a/src/levanter/data/permutation.py b/src/levanter/data/permutation.py
index a0f0566f4..6599d4974 100644
--- a/src/levanter/data/permutation.py
+++ b/src/levanter/data/permutation.py
@@ -1,4 +1,3 @@
-import dataclasses
 from typing import Optional, Sequence

 import jax.random
@@ -29,7 +28,10 @@ def is_finite(self) -> bool:
         return self.dataset.is_finite()

     async def current_len(self) -> Optional[int]:
-        return await self.dataset.current_len()
+        if await self.final_length_is_known():
+            return await self.async_len()
+        # in general, we can't know the current length until we know the final length
+        return None

     async def getitem_async(self, index: int) -> T_co:
         permutation = await self._get_permutation()
@@ -41,9 +43,13 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:

     async def _get_permutation(self):
         if self._permutation is None:
-            self._permutation = Permutation(await self.dataset.async_len(), self.key)
+            self._permutation = Permutation(await self.async_len(), self.key)
         return self._permutation

+    async def wait_until_len_at_least(self, length: int) -> int:
+        # a full shuffle can't serve any index until the underlying dataset's final length is known
+        return await self.async_len()
+

 class EraShufflingDataset(AsyncDataset[T_co]):
     """
@@ -128,8 +134,3 @@ async def wait_until_len_at_least(self, length: int) -> int:
         # wait until we hit the next era
         next_era_end = (length // self.era_length + 1) * self.era_length
         return await self.dataset.wait_until_len_at_least(next_era_end)
-
-
-@dataclasses.dataclass
-class EraConfig:
-    era_length: int
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index fc9ce8052..ad03f6d01 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -29,7 +29,6 @@
 from levanter.data import AsyncDataset
 from levanter.data.dataset import MappedAsyncDataset
 from levanter.data.mixture import MixtureDataset, StopStrategy
-from levanter.data.permutation import EraConfig

 # intercept the logging nonsense here
 from levanter.logging import silence_transformer_nag  # noqa
@@ -103,17 +102,19 @@ async def current_len(self) -> Optional[int]:

     async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
         token_arrays = await self._await_token_cache()
         # logger.info(f"Time to get token cache: {time.time() - time_in}")
+        logger.debug(f"waiting until len is at least {max(indices) + 1}")
         len = await self.wait_until_len_at_least(max(indices) + 1)
         if len is not None and len < max(indices) + 1:
             raise ValueError("Requested indices beyond the end of the dataset")
         offsets = np.array(indices) * self.seq_len
+        logger.debug(f"getting offsets {offsets}")
         with ts.Batch():
             out = []
             for offset in offsets:
                 out.append(token_arrays.data[offset : offset + self.seq_len].read())
         out = await asyncio.gather(*out)
-
+        logger.debug("done reading batch")
         return out

     def get_batch_sync(self, indices: Sequence[int]) -> Sequence[T_co]:
@@ -549,9 +550,9 @@ class LMTaskConfig(abc.ABC):
     enforce_eos: bool = True  # whether to append eos even if the tokenizer doesn't
     ignore_token_id: Optional[int] = None

-    shuffle: bool | EraConfig = False
+    shuffle: bool | int = False
     """whether to shuffle the dataset. True means shuffle the whole dataset, False means don't shuffle.
-    If you want to shuffle in eras, provide an EraConfig (which asks for an era_length)"""
+    If you want to shuffle in eras, set this to the era length (a positive int)"""

     @cached_property
     def the_tokenizer(self) -> PreTrainedTokenizerBase:
@@ -599,8 +600,8 @@ def train_set(

         if self.shuffle is True:
             ds = ds.shuffle(key)
-        elif isinstance(self.shuffle, EraConfig):
-            ds = ds.era_shuffle(self.shuffle.era_length, key=key)
+        elif isinstance(self.shuffle, int) and self.shuffle > 0:
+            ds = ds.era_shuffle(self.shuffle, key=key)

         return ds  # type: ignore

@@ -754,8 +755,8 @@ def train_set(
         def shuffle_ds(ds, key):
             if self.shuffle is True:
                 ds = ds.shuffle(key)
-            elif isinstance(self.shuffle, EraConfig):
-                ds = ds.era_shuffle(self.shuffle.era_length, key=key)
+            elif isinstance(self.shuffle, int) and self.shuffle > 0:
+                ds = ds.era_shuffle(self.shuffle, key=key)
             return ds

diff --git a/src/levanter/store/stress_test_new_cache.py b/src/levanter/store/stress_test_new_cache.py
index c583ede56..66d002abd 100644
--- a/src/levanter/store/stress_test_new_cache.py
+++ b/src/levanter/store/stress_test_new_cache.py
@@ -109,14 +109,13 @@ def ensure_cache(new_cache_path):
 if __name__ == "__main__":
     import sys

-    if not len(sys.argv) == 3:
-        print("Usage: convert_to_new_cache.py old_cache_path new_cache_path")
+    if len(sys.argv) != 2:
+        print("Usage: stress_test_new_cache.py new_cache_path")
         sys.exit(1)

     for split in ["validation", "train"]:
         print(f"Split: {split}", flush=True)
-        in_path = os.path.join(sys.argv[1], split)
-        out_path = os.path.join(sys.argv[2], split)
+        cache_path = os.path.join(sys.argv[1], split)
         # convert_to_new_cache(in_path, out_path)
         # with capture_time() as time_fn:
         #     bench_old_cache(in_path)
@@ -126,24 +125,24 @@ def ensure_cache(new_cache_path):
         exemplar = {"input_ids": np.zeros((SEQ_LEN,), dtype=np.int32)}

         with capture_time() as time_fn:
-            bench_new_cache_serial(exemplar, out_path)
+            bench_new_cache_serial(exemplar, cache_path)
         tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
         print(f"New Cache Serial: {time_fn()} ({tokens_per_second} tps)", flush=True)

         with capture_time() as time_fn:
-            asyncio.run(bench_new_cache_serial_tokenseq(exemplar, out_path))
+            asyncio.run(bench_new_cache_serial_tokenseq(exemplar, cache_path))
         tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
         print(f"New Cache Serial TokenSeq: {time_fn()} ({tokens_per_second} tps)", flush=True)

         with capture_time() as time_fn:
-            bench_new_cache_random(exemplar, out_path)
+            bench_new_cache_random(exemplar, cache_path)
         tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
         print(f"New Cache Random: {time_fn()} ({tokens_per_second} tps)", flush=True)

         with capture_time() as time_fn:
-            asyncio.run(bench_new_cache_permutation_random(exemplar, out_path))
+            asyncio.run(bench_new_cache_permutation_random(exemplar, cache_path))
         tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
         print(f"New Cache Permutation: {time_fn()} ({tokens_per_second} tps)", flush=True)
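
Notes on the config change: the nested shuffle/era_length form is replaced by a single shuffle field that accepts three kinds of values. A minimal sketch of the accepted forms (the 10000 is illustrative, echoing the config above, not a recommendation):

    data:
      shuffle: false    # don't shuffle (the default)
      # shuffle: true   # shuffle the whole dataset (requires the full dataset length to be known)
      # shuffle: 10000  # era-shuffle with an era length of 10000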
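For intuition, the era shuffling that an integer selects (EraShufflingDataset in permutation.py) permutes indices only within fixed-size eras, so early data can be served before the dataset's final length is known. Below is a conceptual sketch of that ordering, not the actual implementation; era_shuffled_order is a hypothetical helper name invented for this illustration:

    import jax.random

    def era_shuffled_order(num_rows: int, era_length: int, key) -> list[int]:
        # permute indices within each fixed-size era; indices never cross era boundaries
        order: list[int] = []
        for era_start in range(0, num_rows, era_length):
            era_size = min(era_length, num_rows - era_start)
            key, subkey = jax.random.split(key)
            perm = jax.random.permutation(subkey, era_size)
            order.extend(era_start + int(i) for i in perm)
        return order

For example, era_shuffled_order(25, 10, jax.random.PRNGKey(0)) shuffles indices 0-9 among themselves, then 10-19, then the final partial era 20-24.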