diff --git a/config/gpt2_small_fast_supervised.yaml b/config/gpt2_small_fast_supervised.yaml
index d71e1267e..93675366d 100644
--- a/config/gpt2_small_fast_supervised.yaml
+++ b/config/gpt2_small_fast_supervised.yaml
@@ -15,6 +15,7 @@ data:
   supervised_data:
     validation_urls:
       - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"
+      - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-validation-evaluation.jsonl.gz"
     cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/"
     input_field: "input"
     output_field: "output"
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 70c1fe4b3..f2bea44b2 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -631,7 +631,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
     input_field = config.input_field
     output_field = config.output_field

-    output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)}
+    output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)}

     dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar)  # type: ignore
     dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)  # type: ignore
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index fe5e5dd35..79095d601 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -160,13 +160,13 @@ def main(config: TrainLmConfig):

         levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)})

+        max_eval_examples_per_ds = config.trainer.max_eval_batches
+        if max_eval_examples_per_ds is not None:
+            max_eval_examples_per_ds *= config.trainer.eval_batch_size
+
         if len(tagged_eval_datasets) == 0:
             logger.warning("No evaluation datasets provided.")
         else:
-            max_eval_examples_per_ds = config.trainer.max_eval_batches
-            if max_eval_examples_per_ds is not None:
-                max_eval_examples_per_ds *= config.trainer.eval_batch_size
-
             causal_datasets = [
                 (CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags)
                 for ds, tags in tagged_eval_datasets