Skip to content

Commit

Permalink
get rid of eraconfig b/c draccus can't handle it
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Sep 13, 2024
1 parent 5c18557 commit b6f334e
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 26 deletions.
3 changes: 1 addition & 2 deletions config/llama2_small_fast_mix.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
data:
tokenizer: "meta-llama/Llama-2-7b-hf"
cache_dir: "gs://levanter-data/new-tokenized/pile_mix/"
shuffle:
era_length: 10000
shuffle: 10000
configs:
arxiv:
train_urls:
Expand Down
17 changes: 9 additions & 8 deletions src/levanter/data/permutation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import dataclasses
from typing import Optional, Sequence

import jax.random
Expand Down Expand Up @@ -29,7 +28,11 @@ def is_finite(self) -> bool:
return self.dataset.is_finite()

async def current_len(self) -> Optional[int]:
    """Return the dataset's length if it is final, otherwise ``None``.

    In general, we can't know the current length until we know the entire
    length, so no partial length is reported while the final length is
    still unknown.

    Returns:
        The result of ``async_len()`` once ``final_length_is_known()`` is
        true; ``None`` before that.
    """
    if await self.final_length_is_known():
        return await self.async_len()
    return None

async def getitem_async(self, index: int) -> T_co:
permutation = await self._get_permutation()
Expand All @@ -41,9 +44,12 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:

async def _get_permutation(self):
    """Return the permutation for this dataset, building it on first use.

    The permutation is created once from ``await self.async_len()`` and
    ``self.key``, stored on ``self._permutation``, and reused on every
    later call.
    """
    if self._permutation is None:
        # Construct exactly once; async_len() is awaited only on the first call.
        self._permutation = Permutation(await self.async_len(), self.key)
    return self._permutation

async def wait_until_len_at_least(self, length: int) -> int:
    """Wait for this dataset's length and return it.

    The requested ``length`` is not consulted: the call always awaits
    ``async_len()`` and returns that value.
    """
    total_len = await self.async_len()
    return total_len


class EraShufflingDataset(AsyncDataset[T_co]):
"""
Expand Down Expand Up @@ -128,8 +134,3 @@ async def wait_until_len_at_least(self, length: int) -> int:
# wait until we hit the next era
next_era_end = (length // self.era_length + 1) * self.era_length
return await self.dataset.wait_until_len_at_least(next_era_end)


@dataclasses.dataclass
class EraConfig:
era_length: int
17 changes: 9 additions & 8 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from levanter.data import AsyncDataset
from levanter.data.dataset import MappedAsyncDataset
from levanter.data.mixture import MixtureDataset, StopStrategy
from levanter.data.permutation import EraConfig

# intercept the logging nonsense here
from levanter.logging import silence_transformer_nag # noqa
Expand Down Expand Up @@ -103,17 +102,19 @@ async def current_len(self) -> Optional[int]:
async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
    """Read the sequences at ``indices`` from the token cache.

    Index ``i`` selects the ``seq_len`` entries of the cached token array
    starting at offset ``i * seq_len``.

    Args:
        indices: Sequence positions to fetch. An empty sequence yields an
            empty result.

    Returns:
        One read result per requested index, in the order of ``indices``.

    Raises:
        ValueError: If ``wait_until_len_at_least`` reports a length that is
            too short to contain the largest requested index.
    """
    # len() rather than truthiness so array-like `indices` work too.
    if len(indices) == 0:
        return []

    token_arrays = await self._await_token_cache()
    # logger.info(f"Time to get token cache: {time.time() - time_in}")
    needed_len = max(indices) + 1
    available_len = await self.wait_until_len_at_least(needed_len)
    if available_len is not None and available_len < needed_len:
        raise ValueError("Requested indices beyond the end of the dataset")

    offsets = np.array(indices) * self.seq_len
    # Start every read under a single ts.Batch() context, then await them together.
    with ts.Batch():
        pending = [token_arrays.data[offset : offset + self.seq_len].read() for offset in offsets]

    return await asyncio.gather(*pending)

def get_batch_sync(self, indices: Sequence[int]) -> Sequence[T_co]:
Expand Down Expand Up @@ -549,9 +550,9 @@ class LMTaskConfig(abc.ABC):
enforce_eos: bool = True # whether to append eos even if the tokenizer doesn't

ignore_token_id: Optional[int] = None
shuffle: bool | EraConfig = False
shuffle: bool | int = False
"""whether to shuffle the dataset. True means shuffle the whole dataset, False means don't shuffle.
If you want to shuffle in eras, provide an EraConfig (which asks for an era_length)"""
If you want to shuffle in eras, set this to the era length"""

@cached_property
def the_tokenizer(self) -> PreTrainedTokenizerBase:
Expand Down Expand Up @@ -599,8 +600,8 @@ def train_set(

if self.shuffle is True:
ds = ds.shuffle(key)
elif isinstance(self.shuffle, EraConfig):
ds = ds.era_shuffle(self.shuffle.era_length, key=key)
elif isinstance(self.shuffle, int):
ds = ds.era_shuffle(self.shuffle, key=key)

return ds # type: ignore

Expand Down Expand Up @@ -754,8 +755,8 @@ def train_set(
def shuffle_ds(ds, key):
    """Apply the shuffling selected by ``self.shuffle`` to ``ds``.

    * ``True``: shuffle the whole dataset via ``ds.shuffle(key)``.
    * a positive ``int``: shuffle in eras of that length via
      ``ds.era_shuffle(era_length, key=key)``.
    * anything else (including ``False``): return ``ds`` unchanged.
    """
    shuffle = self.shuffle
    if shuffle is True:
        return ds.shuffle(key)

    # bool is a subclass of int, so False would otherwise be treated as an
    # era length; exclude bools (and non-positive lengths) explicitly.
    is_era_length = isinstance(shuffle, int) and not isinstance(shuffle, bool)
    if is_era_length and shuffle > 0:
        return ds.era_shuffle(shuffle, key=key)

    return ds

Expand Down
15 changes: 7 additions & 8 deletions src/levanter/store/stress_test_new_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,13 @@ def ensure_cache(new_cache_path):
if __name__ == "__main__":
import sys

if not len(sys.argv) == 3:
print("Usage: convert_to_new_cache.py old_cache_path new_cache_path")
if not len(sys.argv) == 2:
print("Usage: convert_to_new_cache.py new_cache_path")
sys.exit(1)

for split in ["validation", "train"]:
print(f"Split: {split}", flush=True)
in_path = os.path.join(sys.argv[1], split)
out_path = os.path.join(sys.argv[2], split)
cache_path = os.path.join(sys.argv[1], split)
# convert_to_new_cache(in_path, out_path)
# with capture_time() as time_fn:
# bench_old_cache(in_path)
Expand All @@ -126,24 +125,24 @@ def ensure_cache(new_cache_path):
exemplar = {"input_ids": np.zeros((SEQ_LEN,), dtype=np.int32)}

with capture_time() as time_fn:
bench_new_cache_serial(exemplar, out_path)
bench_new_cache_serial(exemplar, cache_path)
tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()
print(f"New Cache Serial: {time_fn()} ({tokens_per_second} tps)", flush=True)

with capture_time() as time_fn:
asyncio.run(bench_new_cache_serial_tokenseq(exemplar, out_path))
asyncio.run(bench_new_cache_serial_tokenseq(exemplar, cache_path))
tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()

print(f"New Cache Serial TokenSeq: {time_fn()} ({tokens_per_second} tps)", flush=True)

with capture_time() as time_fn:
bench_new_cache_random(exemplar, out_path)
bench_new_cache_random(exemplar, cache_path)
tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()

print(f"New Cache Random: {time_fn()} ({tokens_per_second} tps)", flush=True)

with capture_time() as time_fn:
asyncio.run(bench_new_cache_permutation_random(exemplar, out_path))
asyncio.run(bench_new_cache_permutation_random(exemplar, cache_path))
tokens_per_second = SEQ_LEN * BS * BATCHES / time_fn()

print(f"New Cache Permutation: {time_fn()} ({tokens_per_second} tps)", flush=True)

0 comments on commit b6f334e

Please sign in to comment.