From 1d646c7ff9cb02d7478c9d7f02b5c900dd58605e Mon Sep 17 00:00:00 2001
From: David Hall
Date: Thu, 12 Dec 2024 14:28:21 -0800
Subject: [PATCH] switch to doing lm eval harness evals in bf16 (#841)

---
 src/levanter/eval_harness.py  | 31 ++++++++++++++++++++++++-------
 src/levanter/main/train_lm.py |  4 +++-
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py
index b9f4381fe..cb15125e0 100644
--- a/src/levanter/eval_harness.py
+++ b/src/levanter/eval_harness.py
@@ -29,6 +29,7 @@
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import jmp
 import numpy as np
 
 import haliax
@@ -139,13 +140,22 @@ def _truncate_or_pad(self, encoded: list[int], prompt_length: int):
 
 
 class LevanterHarnessLM(LM):
-    def __init__(self, EvalBatch: hax.Axis, EvalPos: hax.Axis, model: LmHeadModel, axis_resources, tokenizer):
+    def __init__(
+        self,
+        EvalBatch: hax.Axis,
+        EvalPos: hax.Axis,
+        model: LmHeadModel,
+        axis_resources,
+        tokenizer,
+        mp: jmp.Policy | None,
+    ):
         super().__init__()
         self.EvalBatch = EvalBatch
         self.EvalPos = EvalPos
         self.model = model
         self.axis_resources = axis_resources
         self.tokenizer = tokenizer
+        self.mp = mp
 
         def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedArray, NamedArray]:
             """
@@ -153,6 +163,10 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedAr
             - loss: The negative log-likelihood of the completion.
             - correct: Whether the completion is correct
             """
+
+            if self.mp is not None:
+                model = self.mp.cast_to_compute(model)
+
             logits = model(example.tokens, attn_mask=example.attn_mask)
             logits = logits.astype(jnp.float32)
             Pos = logits.resolve_axis(self.EvalPos.name)
@@ -352,11 +366,11 @@ def run_lm_eval_harness(
     tokenizer,
     EvalBatch,
     axis_resources,
+    mp: jmp.Policy | None,
 ) -> dict:
-    # tasks_to_run = tasks.get_task_dict(config.task_spec_or_default(), tasks.TaskManager())
     tasks_to_run = config.to_task_dict()
 
-    outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources)
+    outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources, mp)
 
     return outputs
 
@@ -368,6 +382,7 @@ def _actually_run_eval_harness(
     tokenizer: HfTokenizer,
     EvalBatch: haliax.Axis,
     axis_resources: ResourceMapping,
+    mp: jmp.Policy | None,
 ):
     """
     Actually run the LM Eval Harness on the given model and tasks. This is a separate function so that it can be used
@@ -382,7 +397,7 @@ def _actually_run_eval_harness(
     max_eval_length = config.max_eval_length
 
     EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length)
-    harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer)
+    harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp)
     # we always set log_samples here and filter out the samples later if we don't want them
     outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=True)
 
@@ -572,6 +587,7 @@ def run_eval_harness_main(config: EvalHarnessMainConfig):
            tokenizer,
            config.EvalBatch,
            axis_resources=compute_axis_mapping,
+           mp=config.trainer.mp,
        )
        logger.info("Finished running LM eval harness")
 
@@ -615,7 +631,7 @@ def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter.
     tracker.log(to_log, step=None)
 
 
-def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources):
+def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources, mp: jmp.Policy | None):
     tasks_to_run = config.to_task_dict()
 
     def lm_eval_harness(step: StepInfo, force=False):
@@ -630,9 +646,12 @@ def lm_eval_harness(step: StepInfo, force=False):
             tokenizer,
             EvalBatch,
             axis_resources,
+            mp,
         )
 
         if jax.process_index() == 0:
+            log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
+
             # don't delete b/c wandb will sometimes defer upload
             with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
                 import json
                     f.name, name=f"lm_eval_harness_results.{step.step}.json", type="lm_eval_output"
                 )
 
-            log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
-
     return lm_eval_harness
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index 75be8d206..15da92cfa 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -263,7 +263,9 @@ def main(config: TrainLmConfig):
         if config.eval_harness is not None:
             eval_harness = config.eval_harness
             trainer.add_hook(
-                levanter.eval_harness.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping),
+                levanter.eval_harness.lm_eval_harness(
+                    eval_harness, tokenizer, EvalBatch, compute_axis_mapping, trainer.mp
+                ),
                 every=config.eval_harness_steps,
             )
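
For reference, a minimal standalone sketch of the mixed-precision pattern this patch wires through the eval harness: a jmp.Policy casts the model to its compute dtype (bf16) for the forward pass, and the logits are upcast to float32 before the log-likelihood is taken. This sketch is not part of the patch or of Levanter; the toy matmul "forward pass", the `params` dict, and the "p=f32,c=bf16,o=f32" policy string are illustrative stand-ins. In the patch itself the policy comes from `trainer.mp` and the model is a full `LmHeadModel`.

    import jax
    import jax.numpy as jnp
    import jmp

    # Same shape of policy that a mixed-precision trainer.mp typically holds:
    # keep parameters and outputs in f32, run compute in bf16.
    policy = jmp.get_policy("p=f32,c=bf16,o=f32")

    # Stand-in for model weights (a pytree of floating-point arrays).
    params = {"w": jnp.ones((8, 4), dtype=jnp.float32)}

    def eval_loglikelihood(params, tokens_onehot, targets_onehot):
        # Mirror of the patch's _eval_loglikelihood: cast weights and inputs
        # to the compute dtype (bf16) before the forward pass...
        params = policy.cast_to_compute(params)
        tokens_onehot = policy.cast_to_compute(tokens_onehot)
        logits = tokens_onehot @ params["w"]  # toy "forward pass", runs in bf16
        # ...then upcast logits to f32 so log-softmax stays numerically stable.
        logits = logits.astype(jnp.float32)
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.sum(logprobs * targets_onehot, axis=-1)

    tokens = jax.nn.one_hot(jnp.arange(3) % 8, 8)
    targets = jax.nn.one_hot(jnp.arange(3) % 4, 4)
    print(eval_loglikelihood(params, tokens, targets))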