From 7bbf1c2abdbd36c7b3753f890f466989d43f7e4a Mon Sep 17 00:00:00 2001
From: David Hall
Date: Tue, 3 Dec 2024 13:54:37 -0800
Subject: [PATCH] pr

---
 src/levanter/eval_harness.py | 154 +++++++++++++++++++++++++++--------
 1 file changed, 122 insertions(+), 32 deletions(-)

diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py
index 9f6faee23..834e06bd7 100644
--- a/src/levanter/eval_harness.py
+++ b/src/levanter/eval_harness.py
@@ -1,7 +1,21 @@
-# Code for running https://github.com/EleutherAI/lm-evaluation-harness inside Levanter runs
-# References:
-# https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py
-# https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/TPU_cluster.py#L6
+"""
+This module contains code for running the [EleutherAI LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness)
+inside Levanter runs. The EleutherAI LM Evaluation Harness is a tool for evaluating language models on a variety of tasks.
+
+The [run_lm_eval_harness][] function runs the EleutherAI LM Evaluation Harness on a given model and tasks, and returns the
+results.
+
+It can also be used as a callback, via the [lm_eval_harness][] function.
+
+Note that Levanter does not support generation (use vLLM or a similar engine for that), so the [generate_until][] method
+is not implemented. We therefore only support tasks that work with loglikelihood, which is most, though not all, of them.
+
+References:
+
+* https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py
+* https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/TPU_cluster.py#L6
+
+"""
 import dataclasses
 import functools
 import json
@@ -18,6 +32,7 @@
 import numpy as np
 
 import haliax
+from haliax import NamedArray
 
 import levanter.tracker
 from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer
@@ -74,9 +89,12 @@ async def get_batch(self, indices: Sequence[int]) -> List[LmExample]:
         out = []
         pad_token_id = self.tokenizer.pad_token_id
 
+        # lm-harness specifies that args are (context, completion)
         reqs = [(self.examples[i].args[0], self.examples[i].args[1]) for i in indices]
 
         for context, completion in reqs:
+            # it's a bit annoying that we run tokenization twice, but it's the easiest way to get the prompt length
+            # cf. https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/api/model.py#L354
             whole_enc = self.tokenizer(context + completion)
             context_enc = self.tokenizer(context)
 
@@ -89,7 +107,15 @@ async def get_batch(self, indices: Sequence[int]) -> List[LmExample]:
 
         return out
 
-    def _truncate_or_pad(self, encoded, prompt_length):
+    def _truncate_or_pad(self, encoded: list[int], prompt_length: int):
+        """
+        Truncate or pad the encoded sequence to the maximum length of the model.
+        Truncates from the beginning of the sequence, so that the completion is preserved.
+
+        Returns:
+            The truncated or padded sequence and the prompt length. The prompt length can be shorter than the original
+            length if the input was truncated.
+        """
         if self.tokenizer.pad_token_id is None:
             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
 
@@ -120,7 +146,12 @@ def __init__(self, EvalBatch: hax.Axis, EvalPos: hax.Axis, model: LmHeadModel, a
         self.axis_resources = axis_resources
         self.tokenizer = tokenizer
 
-        def _eval_loglikelihood(model: LmHeadModel, example: LmExample):
+        def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedArray, NamedArray]:
+            """
+            Returns:
+                - loss: The negative log-likelihood of the completion.
+                - correct: Whether the model's greedy (argmax) predictions reproduce the completion.
+            """
             logits = model(example.tokens, attn_mask=example.attn_mask)
             logits = logits.astype(jnp.float32)
             Pos = logits.resolve_axis(self.EvalPos.name)
@@ -138,6 +169,7 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample):
             not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool)
             pred_targets = hax.argmax(logits, axis=model.Vocab)
             targets = hax.roll(example.tokens, -1, axis=Pos)
+            # "freebie" marks the positions we don't need to predict (the prompt, and the final token's next token)
             freebie = hax.logical_not(example.loss_mask * not_last_loss_mask)
             correct = hax.all(hax.equal(pred_targets, targets) + freebie, axis=Pos)
 
@@ -156,10 +188,7 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         """
         # pad requests to be a multiple of the batch size
         initial_length = len(requests)
-        dummy_instance = dataclasses.replace(requests[0], arguments=("hello", " there"), idx=len(requests))
-        requests = requests + [dummy_instance] * (self.EvalBatch.size - len(requests) % self.EvalBatch.size)
-        assert len(requests) % self.EvalBatch.size == 0
-        dataset = EvalDataset(self.EvalPos, self.tokenizer, requests)
+        dataset = self._pad_dataset_to_batch_size(requests)
 
         mesh = haliax.partitioning._get_mesh()
 
@@ -178,6 +207,13 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
 
         return result
 
+    def _pad_dataset_to_batch_size(self, requests):
+        dummy_instance = dataclasses.replace(requests[0], arguments=("hello", " there"), idx=len(requests))
+        requests = requests + [dummy_instance] * (self.EvalBatch.size - len(requests) % self.EvalBatch.size)
+        assert len(requests) % self.EvalBatch.size == 0
+        dataset = EvalDataset(self.EvalPos, self.tokenizer, requests)
+        return dataset
+
     def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
         raise NotImplementedError()
 
@@ -201,18 +237,37 @@ class TaskConfig:
     nb that LM Eval Harness has its own TaskConfig, but its defaults are not the same as just passing in a dict, and
     we want the behavior of passing in a dict.
 
+    Nones are not included in the dictionary representation, and LM Eval Harness will use its own defaults for any
+    missing values.
+
+    Docs are copied from the LM Eval Harness task guide, which is the authoritative source for what these fields do.
+    They were copied as of 2024-12-03.
+
     See Also:
-        [LM Eval Harness TaskConfig](https://github.com/EleutherAI/lm-evaluation-harness/blob/0ef7548d7c3f01108e7c12900a5e5eb4b4a668f7/lm_eval/api/task.py#L55)
+        * [LM Eval Harness TaskConfig](https://github.com/EleutherAI/lm-evaluation-harness/blob/0ef7548d7c3f01108e7c12900a5e5eb4b4a668f7/lm_eval/api/task.py#L55)
+        * [LM Eval Harness task guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md#parameters)
     """
 
     task: str
+    """The name of the task to run."""
     task_alias: str | None = None
+    """An alias for the task. We log this name to wandb."""
     num_fewshot: int | None = None
 
     use_prompt: str | None = None
+    """Name of prompt in promptsource to use. If defined, will overwrite doc_to_text, doc_to_target, and doc_to_choice."""
     description: str | None = None
+    """An optional prepended Jinja2 template or string which will be prepended to the few-shot examples passed into the model, often describing the task or providing instructions to a model, such as "The following are questions (with answers) about {{subject}}.\n\n". No delimiters or spacing are inserted between the description and the first few-shot example."""
     target_delimiter: str | None = None
+    """String to insert between input and target output for the datapoint being tested. Defaults to " "."""
    fewshot_delimiter: str | None = None
+    """String to insert between few-shot examples. Defaults to "\\n\\n"."""
+    doc_to_text: str | None = None
+    """Jinja2 template string to process a sample into the appropriate input for the model."""
+    doc_to_target: str | None = None
+    """Jinja2 template string to process a sample into the appropriate target for the model."""
+    doc_to_choice: str | None = None
+    """Jinja2 template string to process a sample into a list of possible string choices for multiple_choice tasks."""
 
     def to_dict(self):
         base_dict = dataclasses.asdict(self)
@@ -221,50 +276,66 @@ def to_dict(self):
 
 @dataclass(frozen=True)
 class LmEvalHarnessConfig:
-    task_spec: list[TaskConfig | str] | None = None
+    task_spec: list[TaskConfig | str]
     max_examples: int | None = None
     max_eval_length: int | None = None
     log_samples: bool = False
 
-    def task_spec_or_default(self) -> list[str | dict]:
-        if self.task_spec is None:
-            return ["hellaswag", "piqa"]
+    def to_task_spec(self) -> list[str | dict]:
         return [task.to_dict() if isinstance(task, TaskConfig) else task for task in self.task_spec]
 
     def to_task_dict(self) -> dict:
+        """
+        Convert the task spec to a dictionary that the LM Eval Harness expects.
+
+        This is a bit more complex than we'd like, because we want to run e.g. Hellaswag 0-shot and 10-shot in the same
+        run, and LM Eval Harness doesn't seem to want to do that by default. So we need to do some hacky stuff to make
+        it work.
+        """
         import lm_eval.tasks as tasks
 
         manager = tasks.TaskManager()
         # we need to do it this way b/c i can't figure out how to run e.g. hellaswag 0 shot and 10 shot in a single run
         this_tasks = {}
-        for task in self.task_spec_or_default():
+        for task in self.to_task_spec():
             try:
                 if isinstance(task, str):
                     this_tasks.update(tasks.get_task_dict(task, manager))
                 else:
                     our_name = task.get("task_alias", task["task"]) if isinstance(task, dict) else task
                     our_name = our_name.replace(" ", "_")
-                    task_dict = tasks.get_task_dict([task], manager)
-                    this_task = task_dict.popitem()[1]
-                    # hacky, but this allows us to run multiple instances of the same task with different fewshot settings
-                    this_task.config.task = our_name
+                    this_task = self._get_task_and_rename(manager, our_name, task)
                     this_tasks[our_name] = this_task
             except Exception:
                 logger.exception(f"Failed to load task {task}")
                 raise ValueError(f"Failed to load task {task}")
 
         return this_tasks
 
+    def _get_task_and_rename(self, manager, our_name, task: dict | str):
+        """
+        Get a task from the task manager and rename it to our_name.
+        LM Eval Harness doesn't seem to want to run multiple instances of the same task with different fewshot settings
+        (or other differences), so we need to hack around that.
+        """
+        import lm_eval.tasks as tasks
+
+        task_dict = tasks.get_task_dict([task], manager)
+        this_task = task_dict.popitem()[1]
+        # hacky, but this allows us to run multiple instances of the same task with different fewshot settings
+        this_task.config.task = our_name
+        return this_task
+
 
 @dataclass(frozen=True)
 class EvalHarnessMainConfig:
+    eval_harness: LmEvalHarnessConfig
     tokenizer: str
     checkpoint_path: str
     checkpoint_is_hf: bool = False
+    """If True, the checkpoint is a HuggingFace checkpoint. Otherwise, it is a Levanter checkpoint."""
     trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig)
     model: LmConfig = dataclasses.field(default_factory=Gpt2Config)
-    eval_harness: LmEvalHarnessConfig = dataclasses.field(default_factory=LmEvalHarnessConfig)
-
     @property
     def EvalBatch(self):
         return self.trainer.EvalBatch
@@ -289,13 +360,32 @@ def run_lm_eval_harness(
     return outputs
 
 
-def _actually_run_eval_harness(config: LmEvalHarnessConfig, model, tasks_to_run, tokenizer, EvalBatch, axis_resources):
+def _actually_run_eval_harness(
+    config: LmEvalHarnessConfig, model: LM, tasks_to_run: dict, tokenizer, EvalBatch, axis_resources
+):
+    """
+    Actually run the LM Eval Harness on the given model and tasks. This is a separate function so that it can be used
+    by the main function and the callback function.
+
+    Args:
+        config: The LmEvalHarnessConfig to use
+        model: The model to evaluate
+        tasks_to_run: The task dict produced by [LmEvalHarnessConfig.to_task_dict][]
+        tokenizer: The tokenizer to use
+        EvalBatch: The batch axis for compute
+        axis_resources: The axis mapping to use for compute
+
+    Returns:
+        The results dict from the LM Eval Harness (the same structure that run_eval_harness_main
+        serializes to JSON).
+
+    """
    max_examples = config.max_examples
     max_eval_length = config.max_eval_length
 
     EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length)
     harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer)
-    # we always log_samples here and filter out the samples later if we don't want them
+    # we always set log_samples here and filter out the samples later if we don't want them
     outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=True)
 
     averages = _compute_averages(outputs)
@@ -350,7 +440,7 @@ def _compute_averages(outputs):
 
     return averages
 
 
-NAT_TO_BIT = 1 / np.log(2)
+BITS_PER_NAT = 1 / np.log(2)
 
 # eval_harness isn't consistent enough for this to actually be workable
 # def _compute_extra_metrics(samples):
@@ -488,7 +578,7 @@ def run_eval_harness_main(config: EvalHarnessMainConfig):
     logger.info("Finished running LM eval harness")
 
     # log the results as json
-    with open("lm_eval_results.json", "w") as f:
+    with open("lm_eval_harness_results.json", "w") as f:
         json.dump(outputs, f, indent=2)
 
     # also write to stdout
     if jax.process_index() == 0:
         print(json.dumps(outputs, indent=2), flush=True)
 
     # also log the results
-    levanter.tracker.current_tracker().log_artifact("lm_eval_results.json")
+    levanter.tracker.current_tracker().log_artifact("lm_eval_harness_results.json", name="lm_eval_harness_results")
     log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
 
     return outputs
@@ -509,6 +599,7 @@ def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter.
     to_log = {}
     for task_name, task_results in report["results"].items():
         for metric_name, metric_value in task_results.items():
+            # remove the ",none" suffix, which eval-harness adds by default for some reason
             if metric_name.endswith(",none"):
                 metric_name = metric_name[:-5]
 
@@ -530,10 +621,8 @@ def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_reso
     tasks_to_run = config.to_task_dict()
 
     def lm_eval_harness(step: StepInfo, force=False):
-        # if step.step == 0 and not force:
-        #     return  # don't run eval on the first step
-
-        print(config.task_spec_or_default())
+        if step.step == 0 and not force:
+            return  # don't run eval on the first step
 
         model = inference_mode(step.model, True)
         outputs = _actually_run_eval_harness(
             config,
             model,
             tasks_to_run,
             tokenizer,
             EvalBatch,
             axis_resources,
         )
 
         if jax.process_index() == 0:
+            # don't delete b/c wandb will sometimes defer upload
            with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
                 import json
 
                 json.dump(outputs, f)
                 levanter.tracker.current_tracker().log_artifact(
                     f.name, name=f"lm_eval_harness_results.{step.step}.json", type="lm_eval_output"
                 )
 
         log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
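
Usage sketch (illustrative, not part of the patch): the snippet below shows how the pieces added above fit together when driving the harness directly from Python. The tokenizer name and checkpoint path are hypothetical placeholders, and the trainer/model fields are left at their defaults for brevity; in practice they must match the checkpoint and normally come from the run's config.

    from levanter.eval_harness import (
        EvalHarnessMainConfig,
        LmEvalHarnessConfig,
        TaskConfig,
        run_eval_harness_main,
    )

    config = EvalHarnessMainConfig(
        eval_harness=LmEvalHarnessConfig(
            task_spec=[
                "piqa",  # a plain string uses the LM Eval Harness defaults for that task
                # TaskConfig lets us alias the task so 0-shot and 10-shot variants can coexist in one run
                TaskConfig(task="hellaswag", task_alias="hellaswag_10shot", num_fewshot=10),
            ],
            max_examples=512,      # cap examples per task for a quicker eval
            max_eval_length=1024,  # truncate/pad examples to this many tokens
        ),
        tokenizer="gpt2",                                 # hypothetical tokenizer name
        checkpoint_path="checkpoints/my-run/step-10000",  # hypothetical checkpoint path
        checkpoint_is_hf=False,                           # set True for a HuggingFace checkpoint
    )

    # Runs the harness, writes lm_eval_harness_results.json, logs it to the tracker,
    # and returns the results dict.
    results = run_eval_harness_main(config)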