diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 70c1fe4b3..7fb844d8f 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -575,7 +575,6 @@ class LMSupervisedDatasetConfig:
     validation_urls: List[str] = ()  # type:ignore
 
 
-
 def preprocess_supervised_example(
     batch, tokenizer: PreTrainedTokenizerBase, input_field: str, output_field: str
 ) -> dict:
@@ -631,7 +630,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
     input_field = config.input_field
     output_field = config.output_field
 
-    output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)}
+    output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)}
 
     dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar)  # type: ignore
     dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)  # type: ignore
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index fe5e5dd35..79095d601 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -160,13 +160,13 @@ def main(config: TrainLmConfig):
 
     levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)})
 
+    max_eval_examples_per_ds = config.trainer.max_eval_batches
+    if max_eval_examples_per_ds is not None:
+        max_eval_examples_per_ds *= config.trainer.eval_batch_size
+
     if len(tagged_eval_datasets) == 0:
         logger.warning("No evaluation datasets provided.")
     else:
-        max_eval_examples_per_ds = config.trainer.max_eval_batches
-        if max_eval_examples_per_ds is not None:
-            max_eval_examples_per_ds *= config.trainer.eval_batch_size
-
         causal_datasets = [
             (CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags)
             for ds, tags in tagged_eval_datasets
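
Reviewer note: a minimal sketch (plain NumPy, not Levanter code) of the shape distinction behind the one-character `output_exemplar` fix above. `np.zeros(())` declares a rank-0 (scalar) field, while `np.zeros((0,))` declares a rank-1 variable-length field. The assumption here, inferred from the diff rather than confirmed against the caching layer, is that the cache infers each field's dtype and rank from the exemplar, so `sources_len` must be declared rank-1 to match the 1-D arrays the preprocessing emits.

    import numpy as np

    # Rank-0 exemplar: declares one scalar int per example.
    scalar_exemplar = np.zeros((), dtype=np.int32)
    assert scalar_exemplar.ndim == 0 and scalar_exemplar.shape == ()

    # Rank-1 exemplar: declares a variable-length 1-D int array per
    # example, which is what this PR switches "sources_len" to.
    vector_exemplar = np.zeros((0,), dtype=np.int32)
    assert vector_exemplar.ndim == 1 and vector_exemplar.shape == (0,)

    # A per-example value such as np.array([5], dtype=np.int32) has
    # ndim == 1, so it matches the rank-1 exemplar but not the rank-0 one.
    value = np.array([5], dtype=np.int32)
    assert value.ndim == vector_exemplar.ndim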