pr
dlwh committed Dec 3, 2024
1 parent 5a4e6ce commit 7bbf1c2
Showing 1 changed file with 122 additions and 32 deletions.
154 changes: 122 additions & 32 deletions src/levanter/eval_harness.py
@@ -1,7 +1,21 @@
# Code for running https://github.com/EleutherAI/lm-evaluation-harness inside Levanter runs
# References:
# https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py
# https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/TPU_cluster.py#L6
"""
This module contains code for running the [EleutherAI LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness)
inside Levanter runs. The EleutherAI LM Evaluation Harness is a tool for evaluating language models on a variety of tasks.
The [run_lm_eval_harness][] function runs the EleutherAI LM Evaluation Harness on a given model and tasks, and returns the
results.
It can also be used as a callback, via the [lm_eval_harness][] function.
Note that Levanter does not support generation (use VLLM or something) and the [generate_until][] method is not implemented.
So we only support tasks that work with loglikelihood, which is most(?) of them.
References:
* https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py
* https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/TPU_cluster.py#L6
"""
import dataclasses
import functools
import json
@@ -18,6 +32,7 @@
import numpy as np

import haliax
from haliax import NamedArray

import levanter.tracker
from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer
@@ -74,9 +89,12 @@ async def get_batch(self, indices: Sequence[int]) -> List[LmExample]:
out = []
pad_token_id = self.tokenizer.pad_token_id

# lm-harness specifies that args are (context, completion)
reqs = [(self.examples[i].args[0], self.examples[i].args[1]) for i in indices]
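# e.g. a single request's args might be ("Q: What is the capital of France? A:", " Paris")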

for context, completion in reqs:
# it's kinda annoying we run tokenization twice, but it's the easiest way to get the prompt length
# cf. https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/api/model.py#L354
whole_enc = self.tokenizer(context + completion)
context_enc = self.tokenizer(context)

@@ -89,7 +107,15 @@ async def get_batch(self, indices: Sequence[int]) -> List[LmExample]:

return out

def _truncate_or_pad(self, encoded, prompt_length):
def _truncate_or_pad(self, encoded: list[int], prompt_length: int):
"""
Truncate or pad the encoded sequence to the maximum length of the model.
Truncates from the beginning of the sequence, so that the completion is preserved.
Returns:
Truncated or padded sequence and the prompt length. The prompt length can be shorter than the original
length if the input was truncated.
"""
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

@@ -120,7 +146,12 @@ def __init__(self, EvalBatch: hax.Axis, EvalPos: hax.Axis, model: LmHeadModel, a
self.axis_resources = axis_resources
self.tokenizer = tokenizer

def _eval_loglikelihood(model: LmHeadModel, example: LmExample):
def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedArray, NamedArray]:
"""
Returns:
- loss: The negative log-likelihood of the completion.
- correct: Whether the completion is correct
"""
logits = model(example.tokens, attn_mask=example.attn_mask)
logits = logits.astype(jnp.float32)
Pos = logits.resolve_axis(self.EvalPos.name)
@@ -138,6 +169,7 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample):
not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool)
pred_targets = hax.argmax(logits, axis=model.Vocab)
targets = hax.roll(example.tokens, -1, axis=Pos)
# "freebie" is the positions we don't need to predict (prompt or final token's next token)
freebie = hax.logical_not(example.loss_mask * not_last_loss_mask)
correct = hax.all(hax.equal(pred_targets, targets) + freebie, axis=Pos)
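# Worked example (hypothetical): with a 3-token prompt and a 1-token completion,
# loss_mask selects only the position that predicts the completion token, so
# every other position is a freebie and `correct` is True iff the argmax
# prediction at that one position equals its target token.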

@@ -156,10 +188,7 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
"""
# pad requests to be a multiple of the batch size
initial_length = len(requests)
dummy_instance = dataclasses.replace(requests[0], arguments=("hello", " there"), idx=len(requests))
requests = requests + [dummy_instance] * (self.EvalBatch.size - len(requests) % self.EvalBatch.size)
assert len(requests) % self.EvalBatch.size == 0
dataset = EvalDataset(self.EvalPos, self.tokenizer, requests)
dataset = self._pad_dataset_to_batch_size(requests)

mesh = haliax.partitioning._get_mesh()

@@ -178,6 +207,13 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:

return result

def _pad_dataset_to_batch_size(self, requests):
dummy_instance = dataclasses.replace(requests[0], arguments=("hello", " there"), idx=len(requests))
requests = requests + [dummy_instance] * (self.EvalBatch.size - len(requests) % self.EvalBatch.size)
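# note: if len(requests) is already a multiple of EvalBatch.size, this appends a
# full extra batch of dummies; wasteful but harmless, since the caller (presumably,
# given initial_length above) slices results back down to the original request count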
assert len(requests) % self.EvalBatch.size == 0
dataset = EvalDataset(self.EvalPos, self.tokenizer, requests)
return dataset

def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
raise NotImplementedError()

@@ -201,18 +237,37 @@ class TaskConfig:
N.B. that LM Eval Harness has its own TaskConfig, but its defaults are not the same as just passing in
a dict, and we want the behavior of passing in a dict.

Nones are not included in the dictionary representation, and LM Eval Harness will use its own defaults for any
missing values.

Docs are copied from the LM Eval Harness task guide, which is the authoritative source
for what these fields do. They were copied as of 2024-12-03.
See Also:
[LM Eval Harness TaskConfig](https://github.com/EleutherAI/lm-evaluation-harness/blob/0ef7548d7c3f01108e7c12900a5e5eb4b4a668f7/lm_eval/api/task.py#L55)
* [LM Eval Harness TaskConfig](https://github.com/EleutherAI/lm-evaluation-harness/blob/0ef7548d7c3f01108e7c12900a5e5eb4b4a668f7/lm_eval/api/task.py#L55)
* [LM Eval Harness task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md#parameters)
"""

task: str
"""The name of the task to run."""
task_alias: str | None = None
"""An alias for the task. We log this name to wandb."""
num_fewshot: int | None = None
"""The number of few-shot examples to prepend to each example."""

use_prompt: str | None = None
"""Name of prompt in promptsource to use. If defined, will overwrite doc_to_text, doc_to_target, and doc_to_choice."""
description: str | None = None
"""An optional prepended Jinja2 template or string which will be prepended to the few-shot examples passed into the model, often describing the task or providing instructions to a model, such as "The following are questions (with answers) about {{subject}}.\n\n". No delimiters or spacing are inserted between the description and the first few-shot example."""
target_delimiter: str | None = None
"""String to insert between input and target output for the datapoint being tested. Defaults to " ". """
fewshot_delimiter: str | None = None
"""String to insert between few-shot examples. Defaults to "\\n\\n". """
doc_to_text: str | None = None
"""Jinja2 template string to process a sample into the appropriate input for the model."""
doc_to_target: str | None = None
"""Jinja2 template string to process a sample into the appropriate target for the model."""
doc_to_choice: str | None = None
"""Jinja2 template string to process a sample into a list of possible string choices for multiple_choice tasks."""

def to_dict(self):
base_dict = dataclasses.asdict(self)
@@ -221,50 +276,66 @@ def to_dict(self):

@dataclass(frozen=True)
class LmEvalHarnessConfig:
task_spec: list[TaskConfig | str] | None = None
task_spec: list[TaskConfig | str]
max_examples: int | None = None
max_eval_length: int | None = None
log_samples: bool = False

def task_spec_or_default(self) -> list[str | dict]:
if self.task_spec is None:
return ["hellaswag", "piqa"]
def to_task_spec(self) -> list[str | dict]:
return [task.to_dict() if isinstance(task, TaskConfig) else task for task in self.task_spec]

def to_task_dict(self) -> dict:
"""
Convert the task spec to a dictionary that the LM Eval Harness expects.
This is a bit more complex than we'd like, because we want to run e.g. Hellaswag 0-shot and 10-shot in the same
run, and LM Eval Harness doesn't seem to want to do that by default. So we need to do some hacky stuff to make
it work.
"""
import lm_eval.tasks as tasks

manager = tasks.TaskManager()
# we need to do it this way b/c I can't figure out how to run e.g. hellaswag 0-shot and 10-shot in a single run
this_tasks = {}
for task in self.task_spec_or_default():
for task in self.to_task_spec():
try:
if isinstance(task, str):
this_tasks.update(tasks.get_task_dict(task, manager))
else:
our_name = task.get("task_alias", task["task"]) if isinstance(task, dict) else task
our_name = our_name.replace(" ", "_")
task_dict = tasks.get_task_dict([task], manager)
this_task = task_dict.popitem()[1]
# hacky, but this allows us to run multiple instances of the same task with different fewshot settings
this_task.config.task = our_name
this_task = self._get_task_and_rename(manager, our_name, task)
this_tasks[our_name] = this_task
except Exception:
logger.exception(f"Failed to load task {task}")
raise ValueError(f"Failed to load task {task}")
return this_tasks

def _get_task_and_rename(self, manager, our_name, task: dict | str):
"""
Get a task from the task manager and rename it to our_name.
LM Eval Harness doesn't seem to want to run multiple instances of the same task with different fewshot settings,
(or other differences) so we need to hack around that.
"""
import lm_eval.tasks as tasks

task_dict = tasks.get_task_dict([task], manager)
this_task = task_dict.popitem()[1]
# hacky, but this allows us to run multiple instances of the same task with different fewshot settings
this_task.config.task = our_name
return this_task


@dataclass(frozen=True)
class EvalHarnessMainConfig:
eval_harness: LmEvalHarnessConfig
tokenizer: str
checkpoint_path: str
checkpoint_is_hf: bool = False
"""If True, the checkpoint is a HuggingFace checkpoint. Otherwise, it is a Levanter checkpoint."""
trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig)
model: LmConfig = dataclasses.field(default_factory=Gpt2Config)

eval_harness: LmEvalHarnessConfig = dataclasses.field(default_factory=LmEvalHarnessConfig)

@property
def EvalBatch(self):
return self.trainer.EvalBatch
@@ -289,13 +360,32 @@ def run_lm_eval_harness(
return outputs


def _actually_run_eval_harness(config: LmEvalHarnessConfig, model, tasks_to_run, tokenizer, EvalBatch, axis_resources):
def _actually_run_eval_harness(
config: LmEvalHarnessConfig, model: LM, tasks_to_run: dict, tokenizer, EvalBatch, axis_resources
):
"""
Actually run the LM Eval Harness on the given model and tasks. This is a separate function so that it can be used
by the main function and the callback function.
Args:
config:
model:
tasks_to_run:
tokenizer: The tokenizer to use
EvalBatch: The batch axis for compute
axis_resources: axis mapping for compute
Returns:
"""
max_examples = config.max_examples
max_eval_length = config.max_eval_length

EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length)
harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer)
# we always log_samples here and filter out the samples later if we don't want them
# we always set log_samples here and filter out the samples later if we don't want them
outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=True)

averages = _compute_averages(outputs)
@@ -350,7 +440,7 @@ def _compute_averages(outputs):
return averages


NAT_TO_BIT = 1 / np.log(2)
BITS_PER_NAT = 1 / np.log(2)
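# e.g., converting a cross-entropy loss measured in nats to bits per token
# (loss_nats is a hypothetical name): loss_bits = loss_nats * BITS_PER_NAT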

# eval_harness isn't consistent enough for this to actually be workable
# def _compute_extra_metrics(samples):
@@ -488,15 +578,15 @@ def run_eval_harness_main(config: EvalHarnessMainConfig):

logger.info("Finished running LM eval harness")
# log the results as json
with open("lm_eval_results.json", "w") as f:
with open("lm_eval_harness_results.json", "w") as f:
json.dump(outputs, f, indent=2)

# also write to stdout
if jax.process_index() == 0:
print(json.dumps(outputs, indent=2), flush=True)

# also log the results
levanter.tracker.current_tracker().log_artifact("lm_eval_results.json")
levanter.tracker.current_tracker().log_artifact("lm_eval_harness_results.json", name="lm_eval_harness_results")
log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())

return outputs
@@ -509,6 +599,7 @@ def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter.
to_log = {}
for task_name, task_results in report["results"].items():
for metric_name, metric_value in task_results.items():
# remove the ",none" suffix, which eval-harness adds by default for some reason (e.g. "acc,none" -> "acc")
if metric_name.endswith(",none"):
metric_name = metric_name[:-5]

@@ -530,10 +621,8 @@ def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_reso
tasks_to_run = config.to_task_dict()

def lm_eval_harness(step: StepInfo, force=False):
# if step.step == 0 and not force:
# return # don't run eval on the first step

print(config.task_spec_or_default())
if step.step == 0 and not force:
return # don't run eval on the first step

model = inference_mode(step.model, True)
outputs = _actually_run_eval_harness(
Expand All @@ -546,12 +635,13 @@ def lm_eval_harness(step: StepInfo, force=False):
)

if jax.process_index() == 0:
# don't delete b/c wandb will sometimes defer upload
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
import json

json.dump(outputs, f)
levanter.tracker.current_tracker().log_artifact(
f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output"
f.name, name=f"lm_eval_harness_results.{step.step}.json", type="lm_eval_output"
)

log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())