
Commit

Simplify tokenization pipeline, make it work with large numbers of shards again, (re)add configuration metadata to cache (#752)

Co-authored-by: Ahmed Ahmed <[email protected]>
dlwh and ahmeda14960 authored Oct 4, 2024
1 parent 71bd696 commit b41838f
Showing 22 changed files with 1,762 additions and 1,862 deletions.
1 change: 1 addition & 0 deletions .dockerignore
@@ -117,3 +117,4 @@ dmypy.json

 # local execution commands
 local_*.sh
+.aider*
78 changes: 78 additions & 0 deletions config/data/dclm_gpt_neo.yaml
@@ -0,0 +1,78 @@
cache_dir: "gs://marin-us-central2/tokenized/gpt_neox/"
tokenizer: "EleutherAI/gpt-neox-20b"
cache_options:
  batch_size: 256
  num_shard_groups: 1024
stop_strategy: restart
shuffle: 100000
configs:
  "dclm":
    train_urls:
      - gs://marin-us-central2/raw/dclm/v2024-07-09-baseline-dedup/**/*.zstd
  # these are just for eval
  "paloma/4chan":
    validation_urls:
      - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz
  "paloma/c4_100_domains":
    validation_urls:
      - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz
  "paloma/c4_en":
    validation_urls:
      - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz
  "paloma/dolma-v1_5":
    validation_urls:
      - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz
  "paloma/dolma_100_programing_languages":
    validation_urls:
      - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz
  "paloma/dolma_100_subreddits":
    validation_urls:
      - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz
  "paloma/falcon-refinedweb":
    validation_urls:
      - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz
  "paloma/gab":
    validation_urls:
      - gs://levanter-data/paloma/gab/val/val*.jsonl.gz
  "paloma/m2d2_s2orc_unsplit":
    validation_urls:
      - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz
  "paloma/m2d2_wikipedia_unsplit":
    validation_urls:
      - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz
  "paloma/manosphere_meta_sep":
    validation_urls:
      - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz
  "paloma/mc4":
    validation_urls:
      - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz
  "paloma/ptb":
    validation_urls:
      - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz
  "paloma/redpajama":
    validation_urls:
      - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz
  "paloma/twitterAAE_HELM_fixed":
    validation_urls:
      - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz
  "paloma/wikitext_103":
    validation_urls:
      - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz
train_weights:
  dclm: 1.0
  paloma/4chan: 0.0
  paloma/c4_100_domains: 0.0
  paloma/c4_en: 0.0
  paloma/dolma-v1_5: 0.0
  paloma/dolma_100_programing_languages: 0.0
  paloma/dolma_100_subreddits: 0.0
  paloma/falcon-refinedweb: 0.0
  paloma/gab: 0.0
  paloma/m2d2_s2orc_unsplit: 0.0
  paloma/m2d2_wikipedia_unsplit: 0.0
  paloma/manosphere_meta_sep: 0.0
  paloma/mc4: 0.0
  paloma/ptb: 0.0
  paloma/redpajama: 0.0
  paloma/twitterAAE_HELM_fixed: 0.0
  paloma/wikitext_103: 0.0
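The `num_shard_groups` option in `cache_options` above bounds how many cache shards the pipeline writes, which is how this commit copes with sources like DCLM whose globs expand to very large numbers of input files. A minimal sketch of the idea, with an illustrative helper name and round-robin assignment that are not levanter's actual code:

```python
from typing import Sequence


def assign_shard_groups(shards: Sequence[str], num_groups: int) -> list[list[str]]:
    """Round-robin input shards into at most `num_groups` groups.

    Capping the group count bounds the number of writer tasks and output
    files regardless of how many input shards the glob expands to.
    """
    groups: list[list[str]] = [[] for _ in range(min(num_groups, len(shards)))]
    for i, shard in enumerate(shards):
        groups[i % len(groups)].append(shard)
    return groups


# e.g. 200,000 hypothetical DCLM files collapse to the 1024 groups configured above
groups = assign_shard_groups([f"shard-{i:06d}.zstd" for i in range(200_000)], 1024)
assert len(groups) == 1024
```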
44 changes: 22 additions & 22 deletions config/data/dolma_olmo_paloma.yaml
@@ -1,59 +1,59 @@
cache_dir: "gs://marin-data/tokenized/OLMo-1B/dolma-v1.7"
cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/dolma/v1.7"
tokenizer: "allenai/OLMo-1B" # requires `pip install ai2-olmo`
# tokenizer: "meta-llama/Llama-2-7b-hf"
stop_strategy: restart
configs:
dolma-algebraic-stack:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/algebraic-stack-train-{0000..0015}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/algebraic-stack-train-{0000..0015}.json.gz
dolma-arxiv:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/arxiv-{0000..0099}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/arxiv-{0000..0099}.json.gz
dolma-gutenberg:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/books-{0000..0002}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/books-{0000..0002}.json.gz
dolma-c4:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/c4-{0000..0170}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/c4-{0000..0170}.json.gz
dolma-cc:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_head-{0000..0274}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_middle-{0240..0379}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing
- gs://marin-data/raw/dolma/dolma-v1.7/cc_en_tail-{0154..0444}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_en_head-{0000..0274}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0000..0238}.json.gz # 239 is missing
- gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0240..0379}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0000..0152}.json.gz # 153 is missing
- gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0154..0444}.json.gz
dolma-cc-news:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/cc_news_head-{0000..0004}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_news_middle-{0000..0002}.json.gz
- gs://marin-data/raw/dolma/dolma-v1.7/cc_news_tail-0000.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_news_head-{0000..0004}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_news_middle-{0000..0002}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/cc_news_tail-0000.json.gz
dolma-falcon:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/falcon-{0000..0499}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/falcon-{0000..0499}.json.gz
dolma-megawika:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/megawika-{0000..0261}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/megawika-{0000..0261}.json.gz
dolma-owmath:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/open-web-math-train-{0000..0012}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/open-web-math-train-{0000..0012}.json.gz
dolma-pes2o:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/pes2o-{0000..0025}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/pes2o-{0000..0025}.json.gz
dolma-reddit:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/reddit-{0000..0077}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/reddit-{0000..0077}.json.gz
dolma-stackexchange:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/stackexchange-{0000..0025}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/stackexchange-{0000..0025}.json.gz
dolma-starcoder:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/starcoder-{0000..0048}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/starcoder-{0000..0048}.json.gz
dolma-flan:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/tulu_flan-{0000..0065}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/tulu_flan-{0000..0065}.json.gz
dolma-wiki:
train_urls:
- gs://marin-data/raw/dolma/dolma-v1.7/wiki-{0000..0001}.json.gz
- gs://marin-us-central2/raw/dolma/v1.7/wiki-{0000..0001}.json.gz
# these are just for eval
"paloma/4chan":
validation_urls:
Expand Down
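The `{0000..0099}`-style URL patterns above are bash-style brace ranges, and the `braceexpand` package (already pinned in the pyproject.toml diff below) expands exactly this syntax, including zero padding. The split ranges such as `{0000..0238}` / `{0240..0379}` skip files the comments note are missing. For example:

```python
from braceexpand import braceexpand

# Expands the numeric range client-side into 100 concrete URLs.
urls = list(braceexpand("gs://marin-us-central2/raw/dolma/v1.7/arxiv-{0000..0099}.json.gz"))
assert len(urls) == 100
assert urls[0].endswith("arxiv-0000.json.gz")
```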
33 changes: 33 additions & 0 deletions config/llama_7b_with_dclm.yaml
@@ -0,0 +1,33 @@
data: !include data/dclm_gpt_neo.yaml
model:  # 7B class model
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
trainer:
  tracker:
    type: wandb
    entity: "stanford-mercury"
    project: "marin"
    tags: ["dclm", "7B", "llama"]

  mp: p=f32,c=bfloat16
  train_batch_size: 2048
  num_train_steps: 70000  # 280B / 4M
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
optimizer:
  learning_rate: 4e-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
  beta1: 0.9
  beta2: 0.95
  warmup: 5000

z_loss_weight: 5e-6
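The `num_train_steps: 70000  # 280B / 4M` comment is token-budget arithmetic: each step consumes `train_batch_size × seq_len` tokens. A quick sanity check (the comment's "280B / 4M" uses round decimal figures; the exact product is slightly higher):

```python
train_batch_size = 2048
seq_len = 2048

tokens_per_step = train_batch_size * seq_len  # 4,194,304, the "4M" in the comment
num_train_steps = 70_000
total_tokens = num_train_steps * tokens_per_step
print(f"{total_tokens / 1e9:.0f}B tokens")  # ~294B exactly; 280e9 / 4e6 = 70,000 with rounded figures
```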
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -37,7 +37,7 @@ dependencies = [
"braceexpand>=0.1.7",
"jmp>=0.0.3",
"fsspec[http]>=2024.2,<2024.10",
"tensorstore==0.1.63",
"tensorstore>=0.1.65",
"pytimeparse>=1.1.8",
"humanfriendly==10.0",
"safetensors[numpy]~=0.4.2",
@@ -50,7 +50,8 @@ dependencies = [
"filelock~=3.13",
# "ai2-olmo",
"async-lru~=2.0",
"tqdm-loggable>=0.2"
"tqdm-loggable>=0.2",
"deepdiff"
]

[project.urls]
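The new `deepdiff` dependency is presumably what the cache-metadata check uses to report how a cache's recorded configuration differs from the current one (the exact call site is not shown in this diff). Its `DeepDiff` class produces a structured delta between nested dicts; a sketch with hypothetical metadata dicts:

```python
from deepdiff import DeepDiff

# Hypothetical metadata of the kind a tokenization cache might record.
cached = {"tokenizer": "EleutherAI/gpt-neox-20b", "append_eos": True}
current = {"tokenizer": "allenai/OLMo-1B", "append_eos": True}

diff = DeepDiff(cached, current)
if diff:
    # e.g. {'values_changed': {"root['tokenizer']": {'new_value': 'allenai/OLMo-1B',
    #                                                'old_value': 'EleutherAI/gpt-neox-20b'}}}
    raise ValueError(f"cache was built with different settings: {diff}")
```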
16 changes: 10 additions & 6 deletions src/levanter/data/_preprocessor.py
@@ -27,10 +27,8 @@ class BatchProcessor(Generic[T_contra, U], ABC):
     @abstractmethod
     def __call__(self, batch: Sequence[T_contra]) -> Sequence[U] | U:  # U can be batched "structure of arrays" form
         """
-        Process a batch of data. You should return either a RecordBatch, a sequence of dicts (one per output
+        Process a batch of data. You should return a sequence of dicts (one per output
         example), or a dict of sequences (one per output field).
-        (We allow Mapping so that you can just return HF's BatchEncoding if you want.)
         """
         raise NotImplementedError

@@ -58,8 +56,10 @@ def num_gpus(self) -> int:
         return 0
 
     @property
-    def batch_size(self) -> int:
-        return 128
+    @abstractmethod
+    def metadata(self) -> Dict[str, Any]:
+        """Any metadata that changes the behavior of this processor."""
+        raise NotImplementedError
 
 
 class _DatasetTransform(ABC):
@@ -150,7 +150,7 @@ def rec(dataset):


 class _CompositeBatchProcessor(BatchProcessor):
-    def __init__(self, transforms, batch_size, num_cpus, num_gpus, resources):
+    def __init__(self, transforms, num_cpus, num_gpus, resources):
         self.transforms = transforms
         self._num_cpus = num_cpus
         self._num_gpus = num_gpus
@@ -207,6 +207,10 @@ def __call__(self, batch):

         return batch
 
+    @property
+    def metadata(self):
+        return {}
 
 
 def dict_from_record_batch(b) -> dict:
     # we follow the convention from hf batchencoding where homogeneous-lengthed arrays are turned into nd arrays
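The new abstract `metadata` property (replacing the hardcoded `batch_size`) makes every processor declare the settings that affect its output, so a cache built under one configuration is not silently reused under another; `_CompositeBatchProcessor` contributes `{}` here. A hypothetical subclass shows the contract (the class and its fields are illustrative, not levanter's actual tokenizer processor):

```python
from typing import Any, Dict, Sequence


class ExampleTokenizerProcessor:  # stands in for a BatchProcessor[str, dict] subclass
    """Illustrative only: what a `metadata` implementation might report."""

    def __init__(self, tokenizer_name: str, append_eos: bool):
        self.tokenizer_name = tokenizer_name
        self.append_eos = append_eos

    def __call__(self, batch: Sequence[str]) -> Dict[str, Any]:
        ...  # tokenize the batch

    @property
    def metadata(self) -> Dict[str, Any]:
        # Everything that changes the tokenized output belongs here, so the
        # cache can detect configuration drift against its stored copy.
        return {"tokenizer": self.tokenizer_name, "append_eos": self.append_eos}
```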
