From 43268e0ed60ee4b17ca91ff5fca1504375d95f4a Mon Sep 17 00:00:00 2001 From: Kamyar Salahi Date: Wed, 25 Sep 2024 16:20:07 -0700 Subject: [PATCH 1/8] Adding supervised data config --- config/gpt2_small_fast_supervised.yaml | 40 +++++++++++ scripts/launch_gpt2_small_fast_tpu.sh | 3 +- src/levanter/data/_preprocessor.py | 4 +- src/levanter/data/sharded_datasource.py | 7 +- src/levanter/data/text.py | 94 ++++++++++++++++++++++++- src/levanter/main/train_lm.py | 20 +++++- 6 files changed, 160 insertions(+), 8 deletions(-) create mode 100644 config/gpt2_small_fast_supervised.yaml diff --git a/config/gpt2_small_fast_supervised.yaml b/config/gpt2_small_fast_supervised.yaml new file mode 100644 index 000000000..0181a3fd4 --- /dev/null +++ b/config/gpt2_small_fast_supervised.yaml @@ -0,0 +1,40 @@ +data: + configs: + owt: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + wikitext: + id: dlwh/wikitext_103_detokenized + train_weights: + owt: 0.6 + wikitext: 0.4 + tokenizer: gpt2 + cache_dir: "gs://levanter-data/tokenized/data_mix" +supervised_data: + validation_urls: + - "gs://marin-us-central2/benchmarks/mmlu/mmlu-abstract_algebra-dev-evaluation.jsonl.gz" + cache_dir: "gs://marin-us-central2/benchmarks/tokenized/mmlu/" +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + tracker: + project: "levanter" + tags: [ "openwebtext+wiki", "gpt2", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 diff --git a/scripts/launch_gpt2_small_fast_tpu.sh b/scripts/launch_gpt2_small_fast_tpu.sh index 7b2634749..342439041 100644 --- a/scripts/launch_gpt2_small_fast_tpu.sh +++ b/scripts/launch_gpt2_small_fast_tpu.sh @@ -2,5 +2,6 @@ python infra/launch.py --foreground --tpu_name levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \ python -m levanter.main.train_lm \ - --config_path config/gpt2_small_fast.yaml \ + --config_path config/gpt2_small_fast_supervised.yaml \ --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $* + \ No newline at end of file diff --git a/src/levanter/data/_preprocessor.py b/src/levanter/data/_preprocessor.py index 9ee1e2dc2..284243ec8 100644 --- a/src/levanter/data/_preprocessor.py +++ b/src/levanter/data/_preprocessor.py @@ -79,13 +79,15 @@ class _BatchMapTransform(_DatasetTransform): num_cpus: int num_gpus: int resources: dict + output_exemplar: Any - def __init__(self, fn, batch_size, num_cpus, num_gpus, resources): + def __init__(self, fn, batch_size, num_cpus, num_gpus, resources, output_exemplar = None): self.fn = fn self.batch_size = batch_size self.num_cpus = num_cpus self.num_gpus = num_gpus self.resources = resources + self.output_exemplar = output_exemplar def as_record_batch(doc: BatchResult) -> pa.RecordBatch: diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 38682616d..74e0c8f3a 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -113,7 +113,7 @@ def map(self, fn: Callable[[T_co], U]) -> "ShardedDataSource[U]": return _MappedShardedDataSource(self, fn) def map_batches( - self, fn: Callable[[list[T_co]], 
BatchResult], batch_size, *, num_cpus=1, num_gpus=0, **resources
+        self, fn: Callable[[list[T_co]], BatchResult], batch_size, *, num_cpus=1, num_gpus=0, output_exemplar=None, **resources
     ) -> "ShardedDataSource[dict]":
         """
         **Lazily** map a function over batches of data. This is useful for doing things like batching data for a model,
@@ -131,7 +131,7 @@
         Returns:
             A new ShardedDataset.
         """
-        return _BatchMappedShardedDataSource(self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, **resources)
+        return _BatchMappedShardedDataSource(self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, output_exemplar=output_exemplar, **resources)
 
 
 def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]:
@@ -478,10 +478,11 @@ def __init__(
         batch_size,
         num_cpus=1,
         num_gpus=0,
+        output_exemplar=None,
         **resources,
     ):
         self.source = source
-        self._transform = _BatchMapTransform(fn, batch_size, num_cpus, num_gpus, resources)
+        self._transform = _BatchMapTransform(fn, batch_size, num_cpus, num_gpus, resources, output_exemplar=output_exemplar)
 
     @property
     def shard_names(self) -> Sequence[str]:
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 20a11d090..feadd692d 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -583,6 +583,98 @@ def tagged_eval_sets(
     return [(eval_sets[name], tags[name]) for name in eval_sets]
 
 
+@dataclass
+class LMSupervisedDatasetConfig(LMDatasetSourceConfig):
+    """This class represents a dataset source with URLs or hf name/id."""
+
+    cache_dir: str = "cache/"
+
+    tags: Optional[List[str]] = None
+    """tags for the dataset. Typically the name of the dataset in the config will be added as a tag as well"""
+    name: Optional[str] = None # name for hf dataset
+
+    validation_urls: List[str] = () # type:ignore
+
+    def token_seq_dataset(
+        self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
+    ) -> Optional[TokenSeqDataset]:
+        cache = self.build_or_load_cache(split, monitors=monitors)
+        if cache is None:
+            return None
+        return TokenSeqDataset(cache, seq_len)
+
+    def validation_set(
+        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
+    ) -> Optional[TokenSeqDataset]:
+        return self.token_seq_dataset("validation", seq_len, monitors)
+
+    def validation_sets(
+        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
+    ) -> Mapping[str, AsyncDataset[np.ndarray]]:
+        validation_set = self.validation_set(seq_len, monitors)
+        if validation_set is not None:
+            return {"": validation_set}
+        else:
+            return {}
+
+    # Add tagged eval set with split for auxiliary and validation dataset
+
+def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase):
+    import levanter.data
+    dataset = levanter.data.datasource_from_jsonl(config.validation_urls)
+
+    def preprocess(batch):
+        sources = [example["input"] for example in batch]
+        targets = [f"{example['output']}{tokenizer.eos_token}" for example in batch]
+        # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how the original code does it.
+ examples = [s + t for s, t in zip(sources, targets)] + sources_tokenized = tokenizer(sources, padding=False, truncation=True) + examples_tokenized = tokenizer(examples, padding=False, truncation=True) + + source_lens = [len(s) for s in sources_tokenized["input_ids"]] + + return { + "input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]], + "sources_len": np.array(source_lens, dtype=np.int32), + } + + output_exemplar = { + "input_ids": np.zeros((0,), dtype=np.int32), + "sources_len": np.zeros((), dtype=np.int32) + } + + dataset = dataset.map_batches(preprocess, batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore + dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore + + def _prepare_example(ex: dict) -> LmExample: + """ + Prepare an example for training. This function converts the (cached) batch encoding into an LmExample. + + It goes through the following steps: + + 1. Pad the batch to the maximum length. + 2. Mask out the input and prompt if requested. + 3. Create an LmExample with the input_ids as the input and the next token as the target. + """ + # annoyingly, pad expects things to be batched so we have to prepend a batch axis + tokenizer.pad_token = tokenizer.eos_token + ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length") + ex = {k: v[0] for k, v in ex.items()} + input_ids = hax.named(ex["input_ids"], "position") + # mask out padding and anything before the start of the target + Pos = input_ids.resolve_axis("position") + # if config.mask_inputs: + # loss_mask = hax.arange(Pos) >= ex["sources_len"] + + # # don't predict the padding + # targets = hax.roll(input_ids, -1, Pos) + # loss_mask = loss_mask & (targets != tokenizer.pad_token_id) + # else: + loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) + lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) + return lm_ex + + return dataset.map(_prepare_example) @dataclass class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): @@ -828,4 +920,4 @@ def build_caches( @property def sources(self) -> dict[str, LMDatasetSourceConfig]: - return self.configs + return self.configs \ No newline at end of file diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 6c96f8b62..3166c91c9 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -16,7 +16,7 @@ from levanter import callbacks from levanter.checkpoint import load_checkpoint from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback -from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig +from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig, LMSupervisedDatasetConfig from levanter.models.gpt2 import Gpt2Config from levanter.models.lm_model import LmConfig, compute_next_token_loss from levanter.optim import AdamConfig, OptimizerConfig @@ -30,6 +30,7 @@ @dataclass class TrainLmConfig: data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) + supervised_data: Optional[LMSupervisedDatasetConfig] = None trainer: TrainerConfig = field(default_factory=TrainerConfig) model: LmConfig = field(default_factory=Gpt2Config) optimizer: OptimizerConfig = field(default_factory=AdamConfig) @@ -170,7 +171,6 @@ def main(config: TrainLmConfig): (CausalLmDataset(ds, Pos, KeyPos, 
ignore_index=config.data.ignore_token_id), tags) for ds, tags in tagged_eval_datasets ] - cb = levanter.eval.cb_tagged_lm_evaluate( EvalBatch, causal_datasets, @@ -182,6 +182,22 @@ def main(config: TrainLmConfig): ) trainer.add_hook(cb, every=config.trainer.steps_per_eval) + if config.supervised_data is not None: + logger.info("Using supervised data") + supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer), "")] + # TODO Add tags + cb = levanter.eval.cb_tagged_lm_evaluate( + EvalBatch, + supervised_eval, + tokenizer, + trainer.device_mesh, + compute_axis_mapping, + max_eval_examples_per_ds, + prefix="internal_eval", + mp=config.trainer.mp, + ) + trainer.add_hook(cb, every=config.trainer.steps_per_eval) + flops_per_token = config.model.flops_per_token(vocab_size) flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None trainer.add_hook( From d6ad71fc8f0eb1cd4325122e163820bdc60cf96d Mon Sep 17 00:00:00 2001 From: Kamyar Salahi Date: Wed, 25 Sep 2024 16:24:22 -0700 Subject: [PATCH 2/8] Fixing linter error --- scripts/launch_gpt2_small_fast_tpu.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/launch_gpt2_small_fast_tpu.sh b/scripts/launch_gpt2_small_fast_tpu.sh index 342439041..437491e01 100644 --- a/scripts/launch_gpt2_small_fast_tpu.sh +++ b/scripts/launch_gpt2_small_fast_tpu.sh @@ -4,4 +4,3 @@ python infra/launch.py --foreground --tpu_name levanter-itest-32 --zone us-centr python -m levanter.main.train_lm \ --config_path config/gpt2_small_fast_supervised.yaml \ --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $* - \ No newline at end of file From f5b32cd3eed90227eda40de8ef4f53dbe6785b66 Mon Sep 17 00:00:00 2001 From: Kamyar Salahi Date: Thu, 26 Sep 2024 23:23:05 -0700 Subject: [PATCH 3/8] Fixing supervised training --- config/gpt2_small_fast_supervised.yaml | 4 +- src/levanter/data/text.py | 112 +++++++++++++------------ tests/test_supervised.py | 26 ++++++ 3 files changed, 86 insertions(+), 56 deletions(-) create mode 100644 tests/test_supervised.py diff --git a/config/gpt2_small_fast_supervised.yaml b/config/gpt2_small_fast_supervised.yaml index 0181a3fd4..56ce7ea36 100644 --- a/config/gpt2_small_fast_supervised.yaml +++ b/config/gpt2_small_fast_supervised.yaml @@ -14,8 +14,8 @@ data: cache_dir: "gs://levanter-data/tokenized/data_mix" supervised_data: validation_urls: - - "gs://marin-us-central2/benchmarks/mmlu/mmlu-abstract_algebra-dev-evaluation.jsonl.gz" - cache_dir: "gs://marin-us-central2/benchmarks/tokenized/mmlu/" + - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz" + cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/" model: type: gpt2 hidden_dim: 768 diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index feadd692d..5a3dbce57 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -389,6 +389,20 @@ def num_gpus(self) -> int: def batch_size(self) -> int: return self._batch_size +def fsspec_expand_glob(url): + expanded_urls = braceexpand.braceexpand(url) + for expanded_url in expanded_urls: + if "*" in expanded_url: + fs = fsspec.core.url_to_fs(expanded_url)[0] + globbed = fs.glob(expanded_url) + # have to append the fs prefix back on + protocol, _ = fsspec.core.split_protocol(expanded_url) + if protocol is None: + yield from globbed + else: + yield from [f"{protocol}://{path}" for path in globbed] + else: + yield expanded_url def concatenate_and_group_texts( 
encoding: BatchEncoding, @@ -520,19 +534,7 @@ def urls_for_split(self, split): else: raise ValueError(f"Unknown split {split}") - def fsspec_expand_glob(url): - if "*" in url: - fs = fsspec.core.url_to_fs(url)[0] - globbed = fs.glob(url) - # have to append the fs prefix back on - protocol, _ = fsspec.core.split_protocol(url) - if protocol is None: - return globbed - return [f"{protocol}://{path}" for path in globbed] - else: - return [url] - - urls = [globbed for pat in urls for url in braceexpand.braceexpand(pat) for globbed in fsspec_expand_glob(url)] + urls = [globbed for url in urls for globbed in fsspec_expand_glob(url)] return urls @@ -619,62 +621,64 @@ def validation_sets( # Add tagged eval set with split for auxiliary and validation dataset -def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase): - import levanter.data - dataset = levanter.data.datasource_from_jsonl(config.validation_urls) - - def preprocess(batch): - sources = [example["input"] for example in batch] - targets = [f"{example['output']}{tokenizer.eos_token}" for example in batch] - # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how the original code does it. - examples = [s + t for s, t in zip(sources, targets)] - sources_tokenized = tokenizer(sources, padding=False, truncation=True) - examples_tokenized = tokenizer(examples, padding=False, truncation=True) +def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase): + sources = [example["input"] for example in batch] + + targets = [f"{example['output']}" for example in batch] + # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how the original code does it. + examples = [s + t for s, t in zip(sources, targets)] + sources_tokenized = tokenizer(sources, padding=False, truncation=True) + examples_tokenized = tokenizer(examples, padding=False, truncation=True) - source_lens = [len(s) for s in sources_tokenized["input_ids"]] + source_lens = [len(s) for s in sources_tokenized["input_ids"]] - return { - "input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]], - "sources_len": np.array(source_lens, dtype=np.int32), - } - - output_exemplar = { - "input_ids": np.zeros((0,), dtype=np.int32), - "sources_len": np.zeros((), dtype=np.int32) + return { + "input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]], + "sources_len": np.array(source_lens, dtype=np.int32), } - dataset = dataset.map_batches(preprocess, batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore - dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore - - def _prepare_example(ex: dict) -> LmExample: - """ - Prepare an example for training. This function converts the (cached) batch encoding into an LmExample. +def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample: + """ + Prepare an example for training. This function converts the (cached) batch encoding into an LmExample. - It goes through the following steps: + It goes through the following steps: - 1. Pad the batch to the maximum length. - 2. Mask out the input and prompt if requested. - 3. Create an LmExample with the input_ids as the input and the next token as the target. - """ + 1. Pad the batch to the maximum length. + 2. Mask out the input and prompt if requested. + 3. 
Create an LmExample with the input_ids as the input and the next token as the target. + """ + with local_cpu_mesh(): # annoyingly, pad expects things to be batched so we have to prepend a batch axis - tokenizer.pad_token = tokenizer.eos_token ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length") ex = {k: v[0] for k, v in ex.items()} input_ids = hax.named(ex["input_ids"], "position") # mask out padding and anything before the start of the target Pos = input_ids.resolve_axis("position") - # if config.mask_inputs: - # loss_mask = hax.arange(Pos) >= ex["sources_len"] - - # # don't predict the padding - # targets = hax.roll(input_ids, -1, Pos) - # loss_mask = loss_mask & (targets != tokenizer.pad_token_id) - # else: - loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) + loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1 + + # don't predict the padding + targets = hax.roll(input_ids, -1, Pos) + loss_mask = loss_mask & (targets != tokenizer.pad_token_id) + loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) return lm_ex - return dataset.map(_prepare_example) +def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase): + import levanter.data + validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)] + dataset = levanter.data.datasource_from_jsonl(validation_urls) + + output_exemplar = { + "input_ids": np.zeros((0,), dtype=np.int32), + "sources_len": np.zeros((), dtype=np.int32) + } + + dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore + dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) @dataclass class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): diff --git a/tests/test_supervised.py b/tests/test_supervised.py new file mode 100644 index 000000000..48856d585 --- /dev/null +++ b/tests/test_supervised.py @@ -0,0 +1,26 @@ +from levanter.data.text import preprocess_supervised_example, _prepare_supervised_example +from transformers import AutoTokenizer +import numpy as np +import haliax + +def test_supervised_eval(): + examples = [{"input": "Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 
3\nAnswer:", "output": "B"}] + tokenizer = AutoTokenizer.from_pretrained("gpt2") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + output = preprocess_supervised_example(examples, tokenizer) + assert len(output["input_ids"][0]) == output["sources_len"][0] + 1 + + ex = {'input_ids': np.array([16742, 477, 269, 287, 1168, 62, 18, 884, 326, + 1168, 62, 18, 58, 87, 60, 29006, 87, 61, + 17, 1343, 269, 8, 318, 257, 2214, 13, 198, + 32, 13, 657, 198, 33, 13, 352, 198, 34, + 13, 362, 198, 35, 13, 513, 198, 33706, 25, + 33], dtype=np.int32), 'sources_len': np.array(45, dtype=np.int32)} + + lm_ex = _prepare_supervised_example(ex, tokenizer) + + assert(lm_ex.loss_mask['position', 44] != False) + assert(haliax.sum(lm_ex.loss_mask) == 1) \ No newline at end of file From 6483b42a8b0acba0c62ca80d8e16eced3aae4b62 Mon Sep 17 00:00:00 2001 From: Kamyar Salahi Date: Thu, 26 Sep 2024 23:25:38 -0700 Subject: [PATCH 4/8] Making linter happy --- src/levanter/data/text.py | 4 ++-- tests/test_supervised.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 5a3dbce57..c89604488 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -623,7 +623,7 @@ def validation_sets( def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase): sources = [example["input"] for example in batch] - + targets = [f"{example['output']}" for example in batch] # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how the original code does it. examples = [s + t for s, t in zip(sources, targets)] @@ -655,7 +655,7 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> # mask out padding and anything before the start of the target Pos = input_ids.resolve_axis("position") loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1 - + # don't predict the padding targets = hax.roll(input_ids, -1, Pos) loss_mask = loss_mask & (targets != tokenizer.pad_token_id) diff --git a/tests/test_supervised.py b/tests/test_supervised.py index 48856d585..49e38b4c4 100644 --- a/tests/test_supervised.py +++ b/tests/test_supervised.py @@ -6,21 +6,21 @@ def test_supervised_eval(): examples = [{"input": "Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 
3\nAnswer:", "output": "B"}] tokenizer = AutoTokenizer.from_pretrained("gpt2") - + if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - + output = preprocess_supervised_example(examples, tokenizer) assert len(output["input_ids"][0]) == output["sources_len"][0] + 1 - + ex = {'input_ids': np.array([16742, 477, 269, 287, 1168, 62, 18, 884, 326, 1168, 62, 18, 58, 87, 60, 29006, 87, 61, 17, 1343, 269, 8, 318, 257, 2214, 13, 198, 32, 13, 657, 198, 33, 13, 352, 198, 34, 13, 362, 198, 35, 13, 513, 198, 33706, 25, 33], dtype=np.int32), 'sources_len': np.array(45, dtype=np.int32)} - + lm_ex = _prepare_supervised_example(ex, tokenizer) - + assert(lm_ex.loss_mask['position', 44] != False) assert(haliax.sum(lm_ex.loss_mask) == 1) \ No newline at end of file From 45d41d8d54712ede785dbecef77f48d4e7221e7f Mon Sep 17 00:00:00 2001 From: Kamyar Salahi Date: Thu, 26 Sep 2024 23:47:24 -0700 Subject: [PATCH 5/8] Making linter happy --- src/levanter/data/_preprocessor.py | 2 +- src/levanter/data/sharded_datasource.py | 17 +++++- src/levanter/data/text.py | 57 +++++++++--------- tests/test_supervised.py | 78 +++++++++++++++++++++---- 4 files changed, 113 insertions(+), 41 deletions(-) diff --git a/src/levanter/data/_preprocessor.py b/src/levanter/data/_preprocessor.py index 284243ec8..170796fb6 100644 --- a/src/levanter/data/_preprocessor.py +++ b/src/levanter/data/_preprocessor.py @@ -81,7 +81,7 @@ class _BatchMapTransform(_DatasetTransform): resources: dict output_exemplar: Any - def __init__(self, fn, batch_size, num_cpus, num_gpus, resources, output_exemplar = None): + def __init__(self, fn, batch_size, num_cpus, num_gpus, resources, output_exemplar=None): self.fn = fn self.batch_size = batch_size self.num_cpus = num_cpus diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 74e0c8f3a..6ebb15cc3 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -113,7 +113,14 @@ def map(self, fn: Callable[[T_co], U]) -> "ShardedDataSource[U]": return _MappedShardedDataSource(self, fn) def map_batches( - self, fn: Callable[[list[T_co]], BatchResult], batch_size, *, num_cpus=1, num_gpus=0, output_exemplar=None, **resources + self, + fn: Callable[[list[T_co]], BatchResult], + batch_size, + *, + num_cpus=1, + num_gpus=0, + output_exemplar=None, + **resources, ) -> "ShardedDataSource[dict]": """ **Lazily** map a function over batches of data. This is useful for doing things like batching data for a model, @@ -131,7 +138,9 @@ def map_batches( Returns: A new ShardedDataset. 
""" - return _BatchMappedShardedDataSource(self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, output_exemplar=output_exemplar, **resources) + return _BatchMappedShardedDataSource( + self, fn, batch_size, num_cpus=num_cpus, num_gpus=num_gpus, output_exemplar=output_exemplar, **resources + ) def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]: @@ -482,7 +491,9 @@ def __init__( **resources, ): self.source = source - self._transform = _BatchMapTransform(fn, batch_size, num_cpus, num_gpus, resources, output_exemplar=output_exemplar) + self._transform = _BatchMapTransform( + fn, batch_size, num_cpus, num_gpus, resources, output_exemplar=output_exemplar + ) @property def shard_names(self) -> Sequence[str]: diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index c89604488..664f067dd 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -389,6 +389,7 @@ def num_gpus(self) -> int: def batch_size(self) -> int: return self._batch_size + def fsspec_expand_glob(url): expanded_urls = braceexpand.braceexpand(url) for expanded_url in expanded_urls: @@ -404,6 +405,7 @@ def fsspec_expand_glob(url): else: yield expanded_url + def concatenate_and_group_texts( encoding: BatchEncoding, seq_len: int, @@ -585,6 +587,7 @@ def tagged_eval_sets( return [(eval_sets[name], tags[name]) for name in eval_sets] + @dataclass class LMSupervisedDatasetConfig(LMDatasetSourceConfig): """This class represents a dataset source with URLs or hf name/id.""" @@ -597,30 +600,31 @@ class LMSupervisedDatasetConfig(LMDatasetSourceConfig): validation_urls: List[str] = () # type:ignore - def token_seq_dataset( - self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Optional[TokenSeqDataset]: - cache = self.build_or_load_cache(split, monitors=monitors) - if cache is None: - return None - return TokenSeqDataset(cache, seq_len) - - def validation_set( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Optional[TokenSeqDataset]: - return self.token_seq_dataset("validation", seq_len, monitors) - - def validation_sets( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, AsyncDataset[np.ndarray]]: - validation_set = self.validation_set(seq_len, monitors) - if validation_set is not None: - return {"": validation_set} - else: - return {} + # def token_seq_dataset( + # self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + # ) -> Optional[TokenSeqDataset]: + # cache = self.build_or_load_cache(split, monitors=monitors) + # if cache is None: + # return None + # return TokenSeqDataset(cache, seq_len) + + # def validation_set( + # self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + # ) -> Optional[TokenSeqDataset]: + # return self.token_seq_dataset("validation", seq_len, monitors) + + # def validation_sets( + # self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + # ) -> Mapping[str, AsyncDataset[np.ndarray]]: + # validation_set = self.validation_set(seq_len, monitors) + # if validation_set is not None: + # return {"": validation_set} + # else: + # return {} # Add tagged eval set with split for auxiliary and validation dataset + def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase): sources = [example["input"] for example in batch] @@ -637,6 +641,7 @@ def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase): "sources_len": np.array(source_lens, dtype=np.int32), } + 
def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample: """ Prepare an example for training. This function converts the (cached) batch encoding into an LmExample. @@ -663,15 +668,14 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) return lm_ex + def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase): import levanter.data + validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)] dataset = levanter.data.datasource_from_jsonl(validation_urls) - output_exemplar = { - "input_ids": np.zeros((0,), dtype=np.int32), - "sources_len": np.zeros((), dtype=np.int32) - } + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)} dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore @@ -680,6 +684,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + @dataclass class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls""" @@ -924,4 +929,4 @@ def build_caches( @property def sources(self) -> dict[str, LMDatasetSourceConfig]: - return self.configs \ No newline at end of file + return self.configs diff --git a/tests/test_supervised.py b/tests/test_supervised.py index 49e38b4c4..e1d9098d2 100644 --- a/tests/test_supervised.py +++ b/tests/test_supervised.py @@ -1,10 +1,18 @@ -from levanter.data.text import preprocess_supervised_example, _prepare_supervised_example -from transformers import AutoTokenizer import numpy as np +from transformers import AutoTokenizer + import haliax +from levanter.data.text import _prepare_supervised_example, preprocess_supervised_example + + def test_supervised_eval(): - examples = [{"input": "Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 3\nAnswer:", "output": "B"}] + examples = [ + { + "input": "Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 
3\nAnswer:", + "output": "B", + } + ] tokenizer = AutoTokenizer.from_pretrained("gpt2") if tokenizer.pad_token is None: @@ -13,14 +21,62 @@ def test_supervised_eval(): output = preprocess_supervised_example(examples, tokenizer) assert len(output["input_ids"][0]) == output["sources_len"][0] + 1 - ex = {'input_ids': np.array([16742, 477, 269, 287, 1168, 62, 18, 884, 326, - 1168, 62, 18, 58, 87, 60, 29006, 87, 61, - 17, 1343, 269, 8, 318, 257, 2214, 13, 198, - 32, 13, 657, 198, 33, 13, 352, 198, 34, - 13, 362, 198, 35, 13, 513, 198, 33706, 25, - 33], dtype=np.int32), 'sources_len': np.array(45, dtype=np.int32)} + ex = { + "input_ids": np.array( + [ + 16742, + 477, + 269, + 287, + 1168, + 62, + 18, + 884, + 326, + 1168, + 62, + 18, + 58, + 87, + 60, + 29006, + 87, + 61, + 17, + 1343, + 269, + 8, + 318, + 257, + 2214, + 13, + 198, + 32, + 13, + 657, + 198, + 33, + 13, + 352, + 198, + 34, + 13, + 362, + 198, + 35, + 13, + 513, + 198, + 33706, + 25, + 33, + ], + dtype=np.int32, + ), + "sources_len": np.array(45, dtype=np.int32), + } lm_ex = _prepare_supervised_example(ex, tokenizer) - assert(lm_ex.loss_mask['position', 44] != False) - assert(haliax.sum(lm_ex.loss_mask) == 1) \ No newline at end of file + assert lm_ex.loss_mask["position", 44] + assert haliax.sum(lm_ex.loss_mask) == 1 From 5370c72a9cfb1c75b07300aa614b7b038c6dfa6c Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 9 Oct 2024 16:57:09 -0400 Subject: [PATCH 6/8] Update src/levanter/data/text.py Co-authored-by: David Hall --- src/levanter/data/text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 664f067dd..acc4ab778 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -629,7 +629,7 @@ def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase): sources = [example["input"] for example in batch] targets = [f"{example['output']}" for example in batch] - # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how the original code does it. + # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how alpaca does it. 
examples = [s + t for s, t in zip(sources, targets)] sources_tokenized = tokenizer(sources, padding=False, truncation=True) examples_tokenized = tokenizer(examples, padding=False, truncation=True) From 2f625d3c952eb3e933a9834b8beef3a9bf9aafc0 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 9 Oct 2024 14:02:17 -0700 Subject: [PATCH 7/8] address david's comments --- src/levanter/data/text.py | 43 +++--------------------------- src/levanter/utils/fsspec_utils.py | 17 +++++++++++- 2 files changed, 19 insertions(+), 41 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 861a017b0..fdd935d82 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -10,7 +10,7 @@ from itertools import chain from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union -import braceexpand + import datasets import equinox as eqx import fsspec @@ -38,6 +38,7 @@ from levanter.store.jagged_array import JaggedArrayStore from levanter.store.tree_store import TreeStore from levanter.utils.hf_utils import num_cpus_used_by_tokenizer +from levanter.utils.fsspec_utils import fsspec_expand_glob silence_transformer_nag() # noqa @@ -378,20 +379,6 @@ def num_gpus(self) -> int: return 0 -def fsspec_expand_glob(url): - expanded_urls = braceexpand.braceexpand(url) - for expanded_url in expanded_urls: - if "*" in expanded_url: - fs = fsspec.core.url_to_fs(expanded_url)[0] - globbed = fs.glob(expanded_url) - # have to append the fs prefix back on - protocol, _ = fsspec.core.split_protocol(expanded_url) - if protocol is None: - yield from globbed - else: - yield from [f"{protocol}://{path}" for path in globbed] - else: - yield expanded_url def concatenate_and_group_texts( @@ -578,7 +565,7 @@ def tagged_eval_sets( @dataclass -class LMSupervisedDatasetConfig(LMDatasetSourceConfig): +class LMSupervisedDatasetConfig: """This class represents a dataset source with URLs or hf name/id.""" cache_dir: str = "cache/" @@ -589,30 +576,6 @@ class LMSupervisedDatasetConfig(LMDatasetSourceConfig): validation_urls: List[str] = () # type:ignore - # def token_seq_dataset( - # self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - # ) -> Optional[TokenSeqDataset]: - # cache = self.build_or_load_cache(split, monitors=monitors) - # if cache is None: - # return None - # return TokenSeqDataset(cache, seq_len) - - # def validation_set( - # self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - # ) -> Optional[TokenSeqDataset]: - # return self.token_seq_dataset("validation", seq_len, monitors) - - # def validation_sets( - # self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - # ) -> Mapping[str, AsyncDataset[np.ndarray]]: - # validation_set = self.validation_set(seq_len, monitors) - # if validation_set is not None: - # return {"": validation_set} - # else: - # return {} - - # Add tagged eval set with split for auxiliary and validation dataset - def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase): sources = [example["input"] for example in batch] diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py index 896ea8450..6a1341bff 100644 --- a/src/levanter/utils/fsspec_utils.py +++ b/src/levanter/utils/fsspec_utils.py @@ -1,5 +1,5 @@ import fsspec - +import braceexpand def exists(url, **kwargs) -> bool: """Check if a file exists on a remote filesystem.""" @@ -11,3 +11,18 @@ def mkdirs(path): """Create a directory and any necessary parent directories.""" 
fs, path = fsspec.core.url_to_fs(path) fs.makedirs(path, exist_ok=True) + +def fsspec_expand_glob(url): + expanded_urls = braceexpand.braceexpand(url) + for expanded_url in expanded_urls: + if "*" in expanded_url: + fs = fsspec.core.url_to_fs(expanded_url)[0] + globbed = fs.glob(expanded_url) + # have to append the fs prefix back on + protocol, _ = fsspec.core.split_protocol(expanded_url) + if protocol is None: + yield from globbed + else: + yield from [f"{protocol}://{path}" for path in globbed] + else: + yield expanded_url \ No newline at end of file From cf2c9e5b20714b2cdd051f60ff4213982fb4c497 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 9 Oct 2024 14:08:32 -0700 Subject: [PATCH 8/8] lint and minor --- scripts/launch_gpt2_small_fast_supervised_tpu.sh | 6 ++++++ scripts/launch_gpt2_small_fast_tpu.sh | 2 +- src/levanter/data/text.py | 2 -- src/levanter/utils/fsspec_utils.py | 4 +++- 4 files changed, 10 insertions(+), 4 deletions(-) create mode 100644 scripts/launch_gpt2_small_fast_supervised_tpu.sh diff --git a/scripts/launch_gpt2_small_fast_supervised_tpu.sh b/scripts/launch_gpt2_small_fast_supervised_tpu.sh new file mode 100644 index 000000000..df38aec99 --- /dev/null +++ b/scripts/launch_gpt2_small_fast_supervised_tpu.sh @@ -0,0 +1,6 @@ +# Launches the "gpt_small_fast" model on a TPU node + +python infra/launch.py --foreground --tpu_name $(whoami)-levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \ + python -m levanter.main.train_lm \ + --config_path config/gpt2_small_fast_supervised.yaml \ + --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $* diff --git a/scripts/launch_gpt2_small_fast_tpu.sh b/scripts/launch_gpt2_small_fast_tpu.sh index df38aec99..0c09cdcfa 100644 --- a/scripts/launch_gpt2_small_fast_tpu.sh +++ b/scripts/launch_gpt2_small_fast_tpu.sh @@ -2,5 +2,5 @@ python infra/launch.py --foreground --tpu_name $(whoami)-levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \ python -m levanter.main.train_lm \ - --config_path config/gpt2_small_fast_supervised.yaml \ + --config_path config/gpt2_small_fast.yaml \ --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $* diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index fdd935d82..dfd16f844 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -379,8 +379,6 @@ def num_gpus(self) -> int: return 0 - - def concatenate_and_group_texts( encoding: BatchEncoding, seq_len: int, diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py index 6a1341bff..452ab3d84 100644 --- a/src/levanter/utils/fsspec_utils.py +++ b/src/levanter/utils/fsspec_utils.py @@ -1,6 +1,7 @@ import fsspec import braceexpand + def exists(url, **kwargs) -> bool: """Check if a file exists on a remote filesystem.""" fs, path = fsspec.core.url_to_fs(url, **kwargs) @@ -12,6 +13,7 @@ def mkdirs(path): fs, path = fsspec.core.url_to_fs(path) fs.makedirs(path, exist_ok=True) + def fsspec_expand_glob(url): expanded_urls = braceexpand.braceexpand(url) for expanded_url in expanded_urls: @@ -25,4 +27,4 @@ def fsspec_expand_glob(url): else: yield from [f"{protocol}://{path}" for path in globbed] else: - yield expanded_url \ No newline at end of file + yield expanded_url