switch to doing lm eval harness evals in bf16 (#841)
dlwh authored Dec 12, 2024
1 parent 1d63849 commit 1d646c7
Showing 2 changed files with 27 additions and 8 deletions.
src/levanter/eval_harness.py (24 additions, 7 deletions)
@@ -29,6 +29,7 @@
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import jmp
 import numpy as np
 
 import haliax
@@ -139,20 +140,33 @@ def _truncate_or_pad(self, encoded: list[int], prompt_length: int):
 
 
 class LevanterHarnessLM(LM):
-    def __init__(self, EvalBatch: hax.Axis, EvalPos: hax.Axis, model: LmHeadModel, axis_resources, tokenizer):
+    def __init__(
+        self,
+        EvalBatch: hax.Axis,
+        EvalPos: hax.Axis,
+        model: LmHeadModel,
+        axis_resources,
+        tokenizer,
+        mp: jmp.Policy | None,
+    ):
         super().__init__()
         self.EvalBatch = EvalBatch
         self.EvalPos = EvalPos
         self.model = model
         self.axis_resources = axis_resources
         self.tokenizer = tokenizer
+        self.mp = mp
 
         def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedArray, NamedArray]:
             """
             Returns:
             - loss: The negative log-likelihood of the completion.
             - correct: Whether the completion is correct
             """
 
+            if self.mp is not None:
+                model = self.mp.cast_to_compute(model)
+
             logits = model(example.tokens, attn_mask=example.attn_mask)
+            logits = logits.astype(jnp.float32)
             Pos = logits.resolve_axis(self.EvalPos.name)
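
Two precision moves happen in the hunk above: the model is cast down to the policy's compute dtype (bfloat16) before the forward pass, and the resulting logits are cast back up to float32 before the log-likelihood is computed, since log-softmax over a large vocabulary loses too much precision in bf16. A minimal sketch of the same split with plain JAX arrays rather than Levanter's NamedArrays (function name and shapes are illustrative, not from this commit):

import jax
import jax.numpy as jnp

def completion_logprob(logits_bf16: jax.Array, targets: jax.Array) -> jax.Array:
    """Sum of per-token log-likelihoods, accumulated in float32."""
    logits = logits_bf16.astype(jnp.float32)          # upcast before normalizing
    logprobs = jax.nn.log_softmax(logits, axis=-1)    # stable in f32
    picked = jnp.take_along_axis(logprobs, targets[..., None], axis=-1)
    return picked[..., 0].sum(axis=-1)

# toy usage: batch=2, seq=5, vocab=11; pretend the forward pass ran in bf16
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 11)).astype(jnp.bfloat16)
targets = jnp.zeros((2, 5), dtype=jnp.int32)
print(completion_logprob(logits, targets).shape)      # (2,)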
@@ -352,11 +366,11 @@ def run_lm_eval_harness(
     tokenizer,
     EvalBatch,
     axis_resources,
+    mp: jmp.Policy | None,
 ) -> dict:
     # tasks_to_run = tasks.get_task_dict(config.task_spec_or_default(), tasks.TaskManager())
     tasks_to_run = config.to_task_dict()
 
-    outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources)
+    outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources, mp)
 
     return outputs

@@ -368,6 +382,7 @@ def _actually_run_eval_harness(
     tokenizer: HfTokenizer,
     EvalBatch: haliax.Axis,
     axis_resources: ResourceMapping,
+    mp: jmp.Policy | None,
 ):
     """
     Actually run the LM Eval Harness on the given model and tasks. This is a separate function so that it can be used
@@ -382,7 +397,7 @@
     max_eval_length = config.max_eval_length
 
     EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length)
-    harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer)
+    harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp)
     # we always set log_samples here and filter out the samples later if we don't want them
     outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=True)

@@ -572,6 +587,7 @@ def run_eval_harness_main(config: EvalHarnessMainConfig):
         tokenizer,
         config.EvalBatch,
         axis_resources=compute_axis_mapping,
+        mp=config.trainer.mp,
     )
 
     logger.info("Finished running LM eval harness")
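
Note that the standalone entry point reuses the trainer's own mixed-precision policy (config.trainer.mp), so eval now runs at the same compute precision as training. A rough illustration of what a jmp.Policy does when the harness calls cast_to_compute; the policy string here is an assumed bf16 setting, not read from this diff:

import jax.numpy as jnp
import jmp

# a bf16-compute policy like the one a trainer config would carry (assumed)
policy = jmp.get_policy("params=float32,compute=bfloat16,output=float32")

params = {"w": jnp.ones((4, 4)), "step": jnp.array(0)}  # mixed float/int tree
compute_params = policy.cast_to_compute(params)

print(compute_params["w"].dtype)     # bfloat16: floating-point leaves are downcast
print(compute_params["step"].dtype)  # int32: integer leaves are left untouched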
@@ -615,7 +631,7 @@ def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter.
     tracker.log(to_log, step=None)
 
 
-def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources):
+def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources, mp: jmp.Policy | None):
     tasks_to_run = config.to_task_dict()
 
     def lm_eval_harness(step: StepInfo, force=False):
@@ -630,9 +646,12 @@ def lm_eval_harness(step: StepInfo, force=False):
             tokenizer,
             EvalBatch,
             axis_resources,
+            mp,
         )
 
         if jax.process_index() == 0:
+            log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
+
             # don't delete b/c wandb will sometimes defer upload
             with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
                 import json
@@ -642,8 +661,6 @@
                     f.name, name=f"lm_eval_harness_results.{step.step}.json", type="lm_eval_output"
                 )
 
-        log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
-
     return lm_eval_harness


src/levanter/main/train_lm.py (3 additions, 1 deletion)
@@ -263,7 +263,9 @@ def main(config: TrainLmConfig):
         if config.eval_harness is not None:
             eval_harness = config.eval_harness
             trainer.add_hook(
-                levanter.eval_harness.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping),
+                levanter.eval_harness.lm_eval_harness(
+                    eval_harness, tokenizer, EvalBatch, compute_axis_mapping, trainer.mp
+                ),
                 every=config.eval_harness_steps,
             )

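Net effect of the train_lm.py change: the periodic eval-harness hook inherits trainer.mp instead of implicitly evaluating in full precision. For reference, a hedged sketch of how such a policy is typically built from a Levanter-style config string (the exact string is an assumption; check your trainer config for the real value):

import jmp

mp = jmp.get_policy("p=f32,c=bfloat16")  # assumed example, not from this diff

print(mp.param_dtype)    # float32: master weights stay in full precision
print(mp.compute_dtype)  # bfloat16: forward passes, eval included, run here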
