diff --git a/config/gpt2_small_fast_supervised.yaml b/config/gpt2_small_fast_supervised.yaml
new file mode 100644
index 000000000..56ce7ea36
--- /dev/null
+++ b/config/gpt2_small_fast_supervised.yaml
@@ -0,0 +1,40 @@
+data:
+  configs:
+    owt:
+      train_urls:
+        - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
+      validation_urls:
+        - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
+    wikitext:
+      id: dlwh/wikitext_103_detokenized
+  train_weights:
+    owt: 0.6
+    wikitext: 0.4
+  tokenizer: gpt2
+  cache_dir: "gs://levanter-data/tokenized/data_mix"
+supervised_data:
+  validation_urls:
+    - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"
+  cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/"
+model:
+  type: gpt2
+  hidden_dim: 768
+  num_heads: 12
+  num_layers: 12
+  seq_len: 1024
+  gradient_checkpointing: true
+  scale_attn_by_inverse_layer_idx: true
+trainer:
+  tracker:
+    project: "levanter"
+    tags: [ "openwebtext+wiki", "gpt2", "itest"]
+
+  mp: p=f32,c=bfloat16
+  model_axis_size: 1
+
+  train_batch_size: 256
+  num_train_steps: 20000
+optimizer:
+  learning_rate: 1E-3
+  weight_decay: 0.1
+  warmup: 0.01
diff --git a/scripts/launch_gpt2_small_fast_supervised_tpu.sh b/scripts/launch_gpt2_small_fast_supervised_tpu.sh
new file mode 100644
index 000000000..df38aec99
--- /dev/null
+++ b/scripts/launch_gpt2_small_fast_supervised_tpu.sh
@@ -0,0 +1,6 @@
+# Launches the "gpt_small_fast" model on a TPU node
+
+python infra/launch.py --foreground --tpu_name $(whoami)-levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \
+    python -m levanter.main.train_lm \
+    --config_path config/gpt2_small_fast_supervised.yaml \
+    --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $*
diff --git a/src/levanter/data/_preprocessor.py b/src/levanter/data/_preprocessor.py
index 09efb364d..3c1f77494 100644
--- a/src/levanter/data/_preprocessor.py
+++ b/src/levanter/data/_preprocessor.py
@@ -79,13 +79,15 @@ class _BatchMapTransform(_DatasetTransform):
     num_cpus: int
     num_gpus: int
    resources: dict
+    output_exemplar: Any
 
-    def __init__(self, fn, batch_size, num_cpus, num_gpus, resources):
+    def __init__(self, fn, batch_size, num_cpus, num_gpus, resources, output_exemplar=None):
         self.fn = fn
         self.batch_size = batch_size
         self.num_cpus = num_cpus
         self.num_gpus = num_gpus
         self.resources = resources
+        self.output_exemplar = output_exemplar
 
 
 def as_record_batch(doc: BatchResult) -> pa.RecordBatch:
diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index 38682616d..6ebb15cc3 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -113,7 +113,14 @@ def map(self, fn: Callable[[T_co], U]) -> "ShardedDataSource[U]":
         return _MappedShardedDataSource(self, fn)
 
     def map_batches(
-        self, fn: Callable[[list[T_co]], BatchResult], batch_size, *, num_cpus=1, num_gpus=0, **resources
+        self,
+        fn: Callable[[list[T_co]], BatchResult],
+        batch_size,
+        *,
+        num_cpus=1,
+        num_gpus=0,
+        output_exemplar=None,
+        **resources,
     ) -> "ShardedDataSource[dict]":
         """
         **Lazily** map a function over batches of data. This is useful for doing things like batching data for a model,
@@ -131,7 +138,9 @@ def map_batches(
         Returns:
             A new ShardedDataset.
         """
-        return _BatchMappedShardedDataSource(self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, **resources)
+        return _BatchMappedShardedDataSource(
+            self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, output_exemplar=output_exemplar, **resources
+        )
 
 
 def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]:
@@ -478,10 +487,13 @@ def __init__(
         batch_size,
         num_cpus=1,
         num_gpus=0,
+        output_exemplar=None,
         **resources,
     ):
         self.source = source
-        self._transform = _BatchMapTransform(fn, batch_size, num_cpus, num_gpus, resources)
+        self._transform = _BatchMapTransform(
+            fn, batch_size, num_cpus, num_gpus, resources, output_exemplar=output_exemplar
+        )
 
     @property
     def shard_names(self) -> Sequence[str]:
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index bcfcad397..dfd16f844 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -10,7 +10,7 @@
 from itertools import chain
 from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union
 
-import braceexpand
+
 import datasets
 import equinox as eqx
 import fsspec
@@ -38,6 +38,7 @@
 from levanter.store.jagged_array import JaggedArrayStore
 from levanter.store.tree_store import TreeStore
 from levanter.utils.hf_utils import num_cpus_used_by_tokenizer
+from levanter.utils.fsspec_utils import fsspec_expand_glob
 
 
 silence_transformer_nag()  # noqa
@@ -509,19 +510,7 @@ def urls_for_split(self, split):
         else:
             raise ValueError(f"Unknown split {split}")
 
-        def fsspec_expand_glob(url):
-            if "*" in url:
-                fs = fsspec.core.url_to_fs(url)[0]
-                globbed = fs.glob(url)
-                # have to append the fs prefix back on
-                protocol, _ = fsspec.core.split_protocol(url)
-                if protocol is None:
-                    return globbed
-                return [f"{protocol}://{path}" for path in globbed]
-            else:
-                return [url]
-
-        urls = [globbed for pat in urls for url in braceexpand.braceexpand(pat) for globbed in fsspec_expand_glob(url)]
+        urls = [globbed for url in urls for globbed in fsspec_expand_glob(url)]
         return urls
 
 
@@ -573,6 +562,79 @@ def tagged_eval_sets(
     return [(eval_sets[name], tags[name]) for name in eval_sets]
 
 
+@dataclass
+class LMSupervisedDatasetConfig:
+    """This class represents a dataset source with URLs or hf name/id."""
+
+    cache_dir: str = "cache/"
+
+    tags: Optional[List[str]] = None
+    """tags for the dataset. Typically the name of the dataset in the config will be added as a tag as well"""
+    name: Optional[str] = None  # name for hf dataset
+
+    validation_urls: List[str] = ()  # type:ignore
+
+
+def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase):
+    sources = [example["input"] for example in batch]
+
+    targets = [f"{example['output']}" for example in batch]
+    # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how alpaca does it.
+    examples = [s + t for s, t in zip(sources, targets)]
+    sources_tokenized = tokenizer(sources, padding=False, truncation=True)
+    examples_tokenized = tokenizer(examples, padding=False, truncation=True)
+
+    source_lens = [len(s) for s in sources_tokenized["input_ids"]]
+
+    return {
+        "input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]],
+        "sources_len": np.array(source_lens, dtype=np.int32),
+    }
+
+
+def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample:
+    """
+    Prepare an example for training. This function converts the (cached) batch encoding into an LmExample.
+
+    It goes through the following steps:
+
+    1. Pad the batch to the maximum length.
+    2. Mask out the input and prompt if requested.
+    3. Create an LmExample with the input_ids as the input and the next token as the target.
+    """
+    with local_cpu_mesh():
+        # annoyingly, pad expects things to be batched so we have to prepend a batch axis
+        ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length")
+        ex = {k: v[0] for k, v in ex.items()}
+        input_ids = hax.named(ex["input_ids"], "position")
+        # mask out padding and anything before the start of the target
+        Pos = input_ids.resolve_axis("position")
+        loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1
+
+        # don't predict the padding
+        targets = hax.roll(input_ids, -1, Pos)
+        loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
+        loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
+        lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
+        return lm_ex
+
+
+def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase):
+    import levanter.data
+
+    validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)]
+    dataset = levanter.data.datasource_from_jsonl(validation_urls)
+
+    output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)}
+
+    dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar)  # type: ignore
+    dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)  # type: ignore
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
+
+
 @dataclass
 class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
     """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index c8316090a..fe5e5dd35 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -16,7 +16,7 @@
 from levanter import callbacks
 from levanter.checkpoint import load_checkpoint
 from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback
-from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig
+from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig, LMSupervisedDatasetConfig
 from levanter.models.gpt2 import Gpt2Config
 from levanter.models.lm_model import LmConfig, compute_next_token_loss
 from levanter.optim import AdamConfig, OptimizerConfig
@@ -30,6 +30,7 @@
 @dataclass
 class TrainLmConfig:
     data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig)
+    supervised_data: Optional[LMSupervisedDatasetConfig] = None
     trainer: TrainerConfig = field(default_factory=TrainerConfig)
     model: LmConfig = field(default_factory=Gpt2Config)
     optimizer: OptimizerConfig = field(default_factory=AdamConfig)
@@ -170,7 +171,6 @@ def main(config: TrainLmConfig):
                 (CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags)
                 for ds, tags in tagged_eval_datasets
             ]
-
             cb = levanter.eval.cb_tagged_lm_evaluate(
                 EvalBatch,
                 causal_datasets,
@@ -182,6 +182,22 @@ def main(config: TrainLmConfig):
             )
             trainer.add_hook(cb, every=config.trainer.steps_per_eval)
 
+        if config.supervised_data is not None:
+            logger.info("Using supervised data")
+            supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer), "")]
+            # TODO Add tags
+            cb = levanter.eval.cb_tagged_lm_evaluate(
+                EvalBatch,
+                supervised_eval,
+                tokenizer,
+                trainer.device_mesh,
+                compute_axis_mapping,
+                max_eval_examples_per_ds,
+                prefix="internal_eval",
+                mp=config.trainer.mp,
+            )
+            trainer.add_hook(cb, every=config.trainer.steps_per_eval)
+
         flops_per_token = config.model.flops_per_token(vocab_size)
         flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
         trainer.add_hook(
diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py
index 896ea8450..452ab3d84 100644
--- a/src/levanter/utils/fsspec_utils.py
+++ b/src/levanter/utils/fsspec_utils.py
@@ -1,4 +1,5 @@
 import fsspec
+import braceexpand
 
 
 def exists(url, **kwargs) -> bool:
@@ -11,3 +12,19 @@ def mkdirs(path):
     """Create a directory and any necessary parent directories."""
     fs, path = fsspec.core.url_to_fs(path)
     fs.makedirs(path, exist_ok=True)
+
+
+def fsspec_expand_glob(url):
+    expanded_urls = braceexpand.braceexpand(url)
+    for expanded_url in expanded_urls:
+        if "*" in expanded_url:
+            fs = fsspec.core.url_to_fs(expanded_url)[0]
+            globbed = fs.glob(expanded_url)
+            # have to append the fs prefix back on
+            protocol, _ = fsspec.core.split_protocol(expanded_url)
+            if protocol is None:
+                yield from globbed
+            else:
+                yield from [f"{protocol}://{path}" for path in globbed]
+        else:
+            yield expanded_url
diff --git a/tests/test_supervised.py b/tests/test_supervised.py
new file mode 100644
index 000000000..e1d9098d2
--- /dev/null
+++ b/tests/test_supervised.py
@@ -0,0 +1,82 @@
+import numpy as np
+from transformers import AutoTokenizer
+
+import haliax
+
+from levanter.data.text import _prepare_supervised_example, preprocess_supervised_example
+
+
+def test_supervised_eval():
+    examples = [
+        {
+            "input": "Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 3\nAnswer:",
+            "output": "B",
+        }
+    ]
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    output = preprocess_supervised_example(examples, tokenizer)
+    assert len(output["input_ids"][0]) == output["sources_len"][0] + 1
+
+    ex = {
+        "input_ids": np.array(
+            [
+                16742,
+                477,
+                269,
+                287,
+                1168,
+                62,
+                18,
+                884,
+                326,
+                1168,
+                62,
+                18,
+                58,
+                87,
+                60,
+                29006,
+                87,
+                61,
+                17,
+                1343,
+                269,
+                8,
+                318,
+                257,
+                2214,
+                13,
+                198,
+                32,
+                13,
+                657,
+                198,
+                33,
+                13,
+                352,
+                198,
+                34,
+                13,
+                362,
+                198,
+                35,
+                13,
+                513,
+                198,
+                33706,
+                25,
+                33,
+            ],
+            dtype=np.int32,
+        ),
+        "sources_len": np.array(45, dtype=np.int32),
+    }
+
+    lm_ex = _prepare_supervised_example(ex, tokenizer)
+
+    assert lm_ex.loss_mask["position", 44]
+    assert haliax.sum(lm_ex.loss_mask) == 1