Skip to content

Commit

Permalink
Allowing internal supervised eval to work without separate eval set
Browse files (browse the repository at this point in the history)
  • Loading branch information
TheQuantumFractal committed Nov 5, 2024
1 parent 45d3e70 commit 957768d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/levanter/data/text.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -575,7 +575,6 @@ class LMSupervisedDatasetConfig:

validation_urls: List[str] = () # type:ignore


def preprocess_supervised_example(
batch, tokenizer: PreTrainedTokenizerBase, input_field: str, output_field: str
) -> dict:
Expand Down Expand Up @@ -631,7 +630,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
input_field = config.input_field
output_field = config.output_field

- output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)}
+ output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)}

dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore
dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore
Expand Down
8 changes: 4 additions & 4 deletions src/levanter/main/train_lm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -160,13 +160,13 @@ def main(config: TrainLmConfig):

levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)})

+ max_eval_examples_per_ds = config.trainer.max_eval_batches
+ if max_eval_examples_per_ds is not None:
+ max_eval_examples_per_ds *= config.trainer.eval_batch_size
+
  if len(tagged_eval_datasets) == 0:
  logger.warning("No evaluation datasets provided.")
  else:
- max_eval_examples_per_ds = config.trainer.max_eval_batches
- if max_eval_examples_per_ds is not None:
- max_eval_examples_per_ds *= config.trainer.eval_batch_size
-
causal_datasets = [
(CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags)
for ds, tags in tagged_eval_datasets
Expand Down

0 comments on commit 957768d

Please sign in to comment.