Adding supervised data config (#746)
ahmeda14960 authored Oct 9, 2024
2 parents 36b29fd + cf2c9e5 commit 8bed0aa
Showing 8 changed files with 257 additions and 20 deletions.
40 changes: 40 additions & 0 deletions config/gpt2_small_fast_supervised.yaml
@@ -0,0 +1,40 @@
data:
configs:
owt:
train_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
validation_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
wikitext:
id: dlwh/wikitext_103_detokenized
train_weights:
owt: 0.6
wikitext: 0.4
tokenizer: gpt2
cache_dir: "gs://levanter-data/tokenized/data_mix"
supervised_data:
validation_urls:
- "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"
cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/"
model:
type: gpt2
hidden_dim: 768
num_heads: 12
num_layers: 12
seq_len: 1024
gradient_checkpointing: true
scale_attn_by_inverse_layer_idx: true
trainer:
tracker:
project: "levanter"
tags: [ "openwebtext+wiki", "gpt2", "itest"]

mp: p=f32,c=bfloat16
model_axis_size: 1

train_batch_size: 256
num_train_steps: 20000
optimizer:
learning_rate: 1E-3
weight_decay: 0.1
warmup: 0.01
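For reference, here is a minimal Python sketch (not part of the commit) of how the supervised_data block above maps onto the LMSupervisedDatasetConfig dataclass introduced in src/levanter/data/text.py further down; it assumes Levanter is importable and reuses the same URLs as the YAML:

from levanter.data.text import LMSupervisedDatasetConfig

# Equivalent of the YAML `supervised_data` block, constructed directly.
supervised_data = LMSupervisedDatasetConfig(
    validation_urls=["gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"],
    cache_dir="gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/",
)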
6 changes: 6 additions & 0 deletions scripts/launch_gpt2_small_fast_supervised_tpu.sh
@@ -0,0 +1,6 @@
# Launches the "gpt2_small_fast_supervised" config on a TPU node

python infra/launch.py --foreground --tpu_name $(whoami)-levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \
python -m levanter.main.train_lm \
--config_path config/gpt2_small_fast_supervised.yaml \
--trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $*
4 changes: 3 additions & 1 deletion src/levanter/data/_preprocessor.py
@@ -79,13 +79,15 @@ class _BatchMapTransform(_DatasetTransform):
num_cpus: int
num_gpus: int
resources: dict
output_exemplar: Any

def __init__(self, fn, batch_size, num_cpus, num_gpus, resources):
def __init__(self, fn, batch_size, num_cpus, num_gpus, resources, output_exemplar=None):
self.fn = fn
self.batch_size = batch_size
self.num_cpus = num_cpus
self.num_gpus = num_gpus
self.resources = resources
self.output_exemplar = output_exemplar


def as_record_batch(doc: BatchResult) -> pa.RecordBatch:
18 changes: 15 additions & 3 deletions src/levanter/data/sharded_datasource.py
@@ -113,7 +113,14 @@ def map(self, fn: Callable[[T_co], U]) -> "ShardedDataSource[U]":
return _MappedShardedDataSource(self, fn)

def map_batches(
self, fn: Callable[[list[T_co]], BatchResult], batch_size, *, num_cpus=1, num_gpus=0, **resources
self,
fn: Callable[[list[T_co]], BatchResult],
batch_size,
*,
num_cpus=1,
num_gpus=0,
output_exemplar=None,
**resources,
) -> "ShardedDataSource[dict]":
"""
**Lazily** map a function over batches of data. This is useful for doing things like batching data for a model,
@@ -131,7 +138,9 @@ def map_batches(
Returns:
A new ShardedDataset.
"""
return _BatchMappedShardedDataSource(self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, **resources)
return _BatchMappedShardedDataSource(
self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, output_exemplar=output_exemplar, **resources
)


def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]:
@@ -478,10 +487,13 @@ def __init__(
batch_size,
num_cpus=1,
num_gpus=0,
output_exemplar=None,
**resources,
):
self.source = source
self._transform = _BatchMapTransform(fn, batch_size, num_cpus, num_gpus, resources)
self._transform = _BatchMapTransform(
fn, batch_size, num_cpus, num_gpus, resources, output_exemplar=output_exemplar
)

@property
def shard_names(self) -> Sequence[str]:
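To illustrate the new output_exemplar argument on map_batches, here is a hedged sketch (not part of the commit; the JSONL path and tokenize_batch are hypothetical). The exemplar declares the structure and dtypes of the mapped records, mirroring how mk_supervised_dataset uses it in text.py below:

import numpy as np

import levanter.data

# Hypothetical JSONL source; the path is illustrative.
source = levanter.data.datasource_from_jsonl(["gs://my-bucket/eval/records.jsonl.gz"])


def tokenize_batch(batch):
    # Hypothetical map function: list of raw records -> dict of numpy arrays.
    return {"input_ids": [np.asarray([len(str(rec))], dtype=np.int32) for rec in batch]}


# Empty arrays with the right dtypes describe the output schema for the caching machinery.
output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32)}

mapped = source.map_batches(tokenize_batch, batch_size=128, output_exemplar=output_exemplar)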
90 changes: 76 additions & 14 deletions src/levanter/data/text.py
@@ -10,7 +10,7 @@
from itertools import chain
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union

import braceexpand

import datasets
import equinox as eqx
import fsspec
@@ -38,6 +38,7 @@
from levanter.store.jagged_array import JaggedArrayStore
from levanter.store.tree_store import TreeStore
from levanter.utils.hf_utils import num_cpus_used_by_tokenizer
from levanter.utils.fsspec_utils import fsspec_expand_glob


silence_transformer_nag() # noqa
@@ -509,19 +510,7 @@ def urls_for_split(self, split):
else:
raise ValueError(f"Unknown split {split}")

def fsspec_expand_glob(url):
if "*" in url:
fs = fsspec.core.url_to_fs(url)[0]
globbed = fs.glob(url)
# have to append the fs prefix back on
protocol, _ = fsspec.core.split_protocol(url)
if protocol is None:
return globbed
return [f"{protocol}://{path}" for path in globbed]
else:
return [url]

urls = [globbed for pat in urls for url in braceexpand.braceexpand(pat) for globbed in fsspec_expand_glob(url)]
urls = [globbed for url in urls for globbed in fsspec_expand_glob(url)]
return urls


@@ -573,6 +562,79 @@ def tagged_eval_sets(
return [(eval_sets[name], tags[name]) for name in eval_sets]


@dataclass
class LMSupervisedDatasetConfig:
"""This class represents a dataset source with URLs or hf name/id."""

cache_dir: str = "cache/"

tags: Optional[List[str]] = None
"""tags for the dataset. Typically the name of the dataset in the config will be added as a tag as well"""
name: Optional[str] = None # name for hf dataset

validation_urls: List[str] = () # type:ignore


def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase):
sources = [example["input"] for example in batch]

targets = [f"{example['output']}" for example in batch]
# TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how alpaca does it.
examples = [s + t for s, t in zip(sources, targets)]
sources_tokenized = tokenizer(sources, padding=False, truncation=True)
examples_tokenized = tokenizer(examples, padding=False, truncation=True)

source_lens = [len(s) for s in sources_tokenized["input_ids"]]

return {
"input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]],
"sources_len": np.array(source_lens, dtype=np.int32),
}


def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample:
"""
Prepare an example for training. This function converts the (cached) batch encoding into an LmExample.
It goes through the following steps:
1. Pad the batch to the maximum length.
2. Mask out the input and prompt if requested.
3. Create an LmExample with the input_ids as the input and the next token as the target.
"""
with local_cpu_mesh():
# annoyingly, pad expects things to be batched so we have to prepend a batch axis
ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length")
ex = {k: v[0] for k, v in ex.items()}
input_ids = hax.named(ex["input_ids"], "position")
# mask out padding and anything before the start of the target
Pos = input_ids.resolve_axis("position")
loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1

# don't predict the padding
targets = hax.roll(input_ids, -1, Pos)
loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
return lm_ex


def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase):
import levanter.data

validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)]
dataset = levanter.data.datasource_from_jsonl(validation_urls)

output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)}

dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore
dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))


@dataclass
class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
"""This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""
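Usage note (not part of the diff): preprocess_supervised_example and mk_supervised_dataset above expect each line of the validation JSONL files to be a JSON object with an "input" (prompt) field and an "output" (completion) field, as exercised by the test further down. A minimal sketch of writing such a record, with illustrative content:

import json

record = {
    "input": "Question: 2 + 2 = ?\nAnswer:",  # prompt; these tokens are masked out of the loss
    "output": " 4",  # completion; the loss is computed on these tokens
}
print(json.dumps(record))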
20 changes: 18 additions & 2 deletions src/levanter/main/train_lm.py
@@ -16,7 +16,7 @@
from levanter import callbacks
from levanter.checkpoint import load_checkpoint
from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback
from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig
from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig, LMSupervisedDatasetConfig
from levanter.models.gpt2 import Gpt2Config
from levanter.models.lm_model import LmConfig, compute_next_token_loss
from levanter.optim import AdamConfig, OptimizerConfig
@@ -30,6 +30,7 @@
@dataclass
class TrainLmConfig:
data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig)
supervised_data: Optional[LMSupervisedDatasetConfig] = None
trainer: TrainerConfig = field(default_factory=TrainerConfig)
model: LmConfig = field(default_factory=Gpt2Config)
optimizer: OptimizerConfig = field(default_factory=AdamConfig)
@@ -170,7 +171,6 @@ def main(config: TrainLmConfig):
(CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags)
for ds, tags in tagged_eval_datasets
]

cb = levanter.eval.cb_tagged_lm_evaluate(
EvalBatch,
causal_datasets,
@@ -182,6 +182,22 @@
)
trainer.add_hook(cb, every=config.trainer.steps_per_eval)

if config.supervised_data is not None:
logger.info("Using supervised data")
supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer), "")]
# TODO Add tags
cb = levanter.eval.cb_tagged_lm_evaluate(
EvalBatch,
supervised_eval,
tokenizer,
trainer.device_mesh,
compute_axis_mapping,
max_eval_examples_per_ds,
prefix="internal_eval",
mp=config.trainer.mp,
)
trainer.add_hook(cb, every=config.trainer.steps_per_eval)

flops_per_token = config.model.flops_per_token(vocab_size)
flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
trainer.add_hook(
17 changes: 17 additions & 0 deletions src/levanter/utils/fsspec_utils.py
@@ -1,4 +1,5 @@
import fsspec
import braceexpand


def exists(url, **kwargs) -> bool:
@@ -11,3 +12,19 @@ def mkdirs(path):
"""Create a directory and any necessary parent directories."""
fs, path = fsspec.core.url_to_fs(path)
fs.makedirs(path, exist_ok=True)


def fsspec_expand_glob(url):
expanded_urls = braceexpand.braceexpand(url)
for expanded_url in expanded_urls:
if "*" in expanded_url:
fs = fsspec.core.url_to_fs(expanded_url)[0]
globbed = fs.glob(expanded_url)
# have to append the fs prefix back on
protocol, _ = fsspec.core.split_protocol(expanded_url)
if protocol is None:
yield from globbed
else:
yield from [f"{protocol}://{path}" for path in globbed]
else:
yield expanded_url
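A quick usage sketch for the new helper (the bucket and shard names are illustrative). Brace patterns are expanded locally by braceexpand; only patterns containing "*" touch the filesystem via fsspec's glob:

from levanter.utils.fsspec_utils import fsspec_expand_glob

# No "*" here, so no filesystem access: the braces expand to four URLs.
urls = list(fsspec_expand_glob("gs://my-bucket/data/shard.{1..4}-of-4.jsonl.gz"))
# -> ['gs://my-bucket/data/shard.1-of-4.jsonl.gz', ..., 'gs://my-bucket/data/shard.4-of-4.jsonl.gz']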
82 changes: 82 additions & 0 deletions tests/test_supervised.py
@@ -0,0 +1,82 @@
import numpy as np
from transformers import AutoTokenizer

import haliax

from levanter.data.text import _prepare_supervised_example, preprocess_supervised_example


def test_supervised_eval():
examples = [
{
"input": "Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 3\nAnswer:",
"output": "B",
}
]
tokenizer = AutoTokenizer.from_pretrained("gpt2")

if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

output = preprocess_supervised_example(examples, tokenizer)
assert len(output["input_ids"][0]) == output["sources_len"][0] + 1

ex = {
"input_ids": np.array(
[
16742,
477,
269,
287,
1168,
62,
18,
884,
326,
1168,
62,
18,
58,
87,
60,
29006,
87,
61,
17,
1343,
269,
8,
318,
257,
2214,
13,
198,
32,
13,
657,
198,
33,
13,
352,
198,
34,
13,
362,
198,
35,
13,
513,
198,
33706,
25,
33,
],
dtype=np.int32,
),
"sources_len": np.array(45, dtype=np.int32),
}

lm_ex = _prepare_supervised_example(ex, tokenizer)

assert lm_ex.loss_mask["position", 44]
assert haliax.sum(lm_ex.loss_mask) == 1
