From bcfc225a80a6ded59b0e901f993322d3c4ea2b76 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 23 Nov 2024 23:55:30 -0800 Subject: [PATCH] eval_harness is about there --- config/harness/eval_marin_dclm_ckpt.yaml | 27 +++ src/levanter/eval_harness.py | 201 +++++++++++++++-------- 2 files changed, 157 insertions(+), 71 deletions(-) create mode 100644 config/harness/eval_marin_dclm_ckpt.yaml diff --git a/config/harness/eval_marin_dclm_ckpt.yaml b/config/harness/eval_marin_dclm_ckpt.yaml new file mode 100644 index 000000000..f503fcb2c --- /dev/null +++ b/config/harness/eval_marin_dclm_ckpt.yaml @@ -0,0 +1,27 @@ +eval_harness: + task_spec: ["hellaswag"] +# max_examples: 9984 # this is the max that ends up being divisible by 512 after expansion + max_examples: 8 # this is the max that ends up being divisible by 512 after expansion + max_eval_length: 128 +#tokenizer: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930 +#tokenizer: gs://levanter-checkpoints/marin/olmoish7b_v4_1024_0627/dlwh_7b0627/hf/step-715001/ +#tokenizer: gs://levanter-checkpoints/marin/olmoish7b_v4_1024_0627/dlwh_7b0627/step-510000/ +#tokenizer: "EleutherAI/gpt-neox-20b" +tokenizer: meta-llama/Meta-Llama-3-8B +model: + type: llama +#checkpoint_path: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930 +checkpoint_path: meta-llama/Meta-Llama-3-8B +checkpoint_is_hf: true +trainer: + mp: f32 + profiler: true + + per_device_parallelism: -1 + train_batch_size: 512 + + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + ray: + auto_start_cluster: false diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 5a72f856e..bde8f688b 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -3,21 +3,23 @@ # https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py # https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/TPU_cluster.py#L6 import dataclasses +import functools import json import logging import typing -import warnings from dataclasses import dataclass from functools import cached_property -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple import equinox as eqx import jax import jax.numpy as jnp -import transformers -from levanter.compat.hf_checkpoints import HFCheckpointConverter +import haliax + +from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer from levanter.models.gpt2 import Gpt2Config +from levanter.models.loss import next_token_loss try: @@ -33,15 +35,14 @@ from tqdm import tqdm import haliax as hax -from haliax.nn import cross_entropy_loss from haliax.partitioning import round_axis_for_partitioning import levanter.config from levanter.checkpoint import load_checkpoint -from levanter.data import batched +from levanter.data import AsyncDataset, DataLoader from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel from levanter.trainer import TrainerConfig -from levanter.utils.jax_utils import stack_tree, use_cpu_device +from levanter.utils.jax_utils import use_cpu_device from levanter.utils.tree_utils import inference_mode @@ -76,94 +77,133 @@ class _RequestType: FINISHED = 3 +@functools.partial(jax.jit, static_argnums=(0, 3)) +def _jit_create_example(Pos, tokens, prompt_len, pad_token_id): + tokens = hax.named(tokens, Pos) + return 
LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id) + + +class EvalDataset(AsyncDataset[LmExample]): + def __init__(self, Pos, tokenizer, examples: list[Instance]): + super().__init__() + self.examples = examples + self.Pos = Pos + self.tokenizer = tokenizer + + async def async_len(self) -> int: + return len(self.examples) + + async def final_length_is_known(self) -> bool: + return True + + def is_finite(self) -> bool: + return True + + async def current_len(self) -> Optional[int]: + return len(self.examples) + + async def get_batch(self, indices: Sequence[int]) -> List[LmExample]: + out = [] + pad_token_id = self.tokenizer.pad_token_id + + reqs = [(self.examples[i].args[0], self.examples[i].args[1]) for i in indices] + + for context, completion in reqs: + whole_enc = self.tokenizer(context + completion) + context_enc = self.tokenizer(context) + + context_enc_len = len(context_enc["input_ids"]) + + tokens, length = self._truncate_or_pad(whole_enc, context_enc_len) + example = _jit_create_example(self.Pos, tokens, length, pad_token_id) + + out.append(example) + + return out + + def _truncate_or_pad(self, encoded, prompt_length): + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + ex_pad = self.tokenizer.pad( + encoded, + padding="max_length", + max_length=self.Pos.size, + return_tensors="np", + ) + + truncated = ex_pad["input_ids"][-self.Pos.size :] + # if we truncated the prompt, we need to adjust the prompt length + if len(truncated) < len(encoded): + prompt_length -= len(encoded) - len(truncated) + if prompt_length < 0: + prompt_length = 0 + logger.warning("Prompt length is negative after truncation. Setting to 0.") + + return truncated, prompt_length + + class LevanterHarnessLM(LM): - def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel, axis_resources, tokenizer): + def __init__(self, EvalBatch: hax.Axis, EvalPos: hax.Axis, model: LmHeadModel, axis_resources, tokenizer): super().__init__() self.EvalBatch = EvalBatch + self.EvalPos = EvalPos self.model = model self.axis_resources = axis_resources self.tokenizer = tokenizer def _eval_loglikelihood(model: LmHeadModel, example: LmExample): - logits = model(example.tokens) + logits = model(example.tokens, attn_mask=example.attn_mask) + logits = logits.astype(jnp.float32) + Pos = logits.resolve_axis(self.EvalPos.name) + + loss = next_token_loss( + Pos=Pos, + Vocab=model.Vocab, + logits=logits, + true_ids=example.tokens, + loss_mask=example.loss_mask, + reduction=hax.sum, + reduction_axis=Pos, + ) - targets = hax.roll(example.tokens, -1, axis=model.Pos.name) - target_y = hax.nn.one_hot(targets, model.Vocab, dtype=logits.dtype) - loss = cross_entropy_loss(logits, model.Vocab, target_y, where=example.loss_mask, reduction_axis=model.Pos) - # to tell if we got the right answer, we want to check that argmax(logits) == tokens wherever loss_mask is 1 + not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool) pred_targets = hax.argmax(logits, axis=model.Vocab) - correct = hax.all(hax.equal(pred_targets, targets) | hax.logical_not(example.loss_mask), axis=model.Pos) + targets = hax.roll(example.tokens, -1, axis=Pos) + freebie = hax.logical_not(example.loss_mask * not_last_loss_mask) + correct = hax.all(hax.equal(pred_targets, targets) + freebie, axis=Pos) - return loss, correct + return -loss, correct # no sharded outputs self._jit_loglikelihood = hax.named_jit( _eval_loglikelihood, axis_resources=axis_resources, out_axis_resources={} ) - def 
_stack_batch(self, examples):
-        return stack_tree(self.EvalBatch, examples, pad_to_batch_size=True)
-
     def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         """
         Compute log-likelihood of generating a continuation from a context.
         Downstream tasks should attempt to use loglikelihood instead of other
         LM calls whenever possible.
-        Args:
-            requests:
-
-        Returns:
-
         """
+        dataset = EvalDataset(self.EvalPos, self.tokenizer, requests)

-        contexts = self.tokenizer([req.args[0] for req in requests])["input_ids"]
-        completions = self.tokenizer([req.args[1] for req in requests])["input_ids"]
-
-        examples: list[LmExample] = []
-
-        @hax.named_jit
-        def _jit_create_example(tokens, prompt_len):
-            tokens = hax.named(tokens, self.model.Pos)
-            return LmExample.from_prompt_and_completion(
-                self.model.Pos, tokens, prompt_len, ignore_id=self.tokenizer.pad_token_id
-            )
+        mesh = haliax.partitioning._get_mesh()

-        # TODO: offload this to an evalbatchloader
-        for context, completion in zip(tqdm(contexts, desc="Creating examples"), completions):
-            tokens, length = self._truncate_or_pad(context, completion)
-            tokens = jnp.array(tokens)
-            length = jnp.array(length)
-            example = _jit_create_example(tokens, length)
-            examples.append(example)
+        loader = DataLoader(
+            self.EvalBatch, dataset, max_buffered_batches=1024, mesh=mesh, axis_resources=self.axis_resources
+        )

         result: list[tuple[float, bool]] = []
-        for batch in batched(tqdm(examples, desc="examples", leave=False), self.EvalBatch.size):
-            logger.info("Processing batch")
-            batch_example = self._stack_batch(batch)
-            # batch_example = jax.device_put(batch_example, jax.local_devices()[0])
-            out_lls, out_correct = self._jit_loglikelihood(self.model, batch_example)
+        for batch in tqdm(loader, desc="Loglikelihood", unit="ba"):
+            out_lls, out_correct = self._jit_loglikelihood(self.model, batch)
             result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array))

         # skip padding
-        result = result[: len(examples)]
+        result = result[: len(requests)]

         return result

-    def _truncate_or_pad(self, context, completion):
-        max_len = self.model.Pos.size
-        if len(completion) > max_len:
-            warnings.warn(f"Completion is longer than max length {max_len}. 
Truncating.") - completion = completion[:max_len] - pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id - - if len(context) + len(completion) > max_len: - context = context[-(max_len - len(completion)) :] - else: - # right pad with padding token - context = context + [pad_token_id] * (max_len - len(context) - len(completion)) - - return jnp.array(context + completion), len(context) - def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: raise NotImplementedError() @@ -171,8 +211,17 @@ def generate_until(self, requests) -> List[str]: raise NotImplementedError() -def run_lm_eval_harness(model, task_spec: list[str], tokenizer, EvalBatch, axis_resources, max_examples=None) -> dict: - harness = LevanterHarnessLM(EvalBatch, model, axis_resources, tokenizer) +def run_lm_eval_harness( + model, + task_spec: list[str], + tokenizer, + EvalBatch, + axis_resources, + max_examples: int | None = None, + max_eval_length: int | None = None, +) -> dict: + EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length) + harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer) tasks_to_run = tasks.get_task_dict(task_spec) outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples) @@ -181,13 +230,14 @@ def run_lm_eval_harness(model, task_spec: list[str], tokenizer, EvalBatch, axis_ @dataclass(frozen=True) class LmEvalHarnessConfig: - task_spec: Optional[list[str]] = None - max_examples: Optional[int] = None + task_spec: list[str] | None = None + max_examples: int | None = None + max_eval_length: int | None = None def task_spec_or_default(self): return self.task_spec or [ # "lambada", - # "piqa", + "piqa", "hellaswag", # "winogrande", # "mathqa", @@ -218,7 +268,7 @@ def EvalBatch(self): @cached_property def the_tokenizer(self): - return transformers.AutoTokenizer.from_pretrained(self.tokenizer) + return load_tokenizer(self.tokenizer) def run_eval_harness_main(config: EvalHarnessConfig): @@ -244,10 +294,10 @@ def run_eval_harness_main(config: EvalHarnessConfig): # initialize the model if config.checkpoint_is_hf: model_config = config.model - converter: HFCheckpointConverter = model_config.default_hf_checkpoint_converter # type: ignore + converter: HFCheckpointConverter = model_config.hf_checkpoint_converter() converter = converter.replaced(reference_checkpoint=config.checkpoint_path, tokenizer=tokenizer) model = converter.load_pretrained( - model_config.model_type, model_config, ref=config.checkpoint_path, dtype=config.trainer.mp.compute_dtype # type: ignore + model_config.model_type, ref=config.checkpoint_path, dtype=config.trainer.mp.compute_dtype # type: ignore ) else: with use_cpu_device(): @@ -265,14 +315,23 @@ def run_eval_harness_main(config: EvalHarnessConfig): config.EvalBatch, axis_resources=compute_axis_mapping, max_examples=max_examples, + max_eval_length=config.eval_harness.max_eval_length, ) logger.info("Finished running LM eval harness") # log the results as json with open("lm_eval_results.json", "w") as f: - json.dump(outputs, f, indent=2) + # also write to stdout + if jax.process_index() == 0: + print(json.dumps(outputs, indent=2), flush=True) + + # also log the results + levanter.tracker.current_tracker().log_artifact("lm_eval_results.json") + + return outputs + if __name__ == "__main__": levanter.config.main(run_eval_harness_main)()
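
A minimal caller sketch (not part of the patch) of the updated run_lm_eval_harness entry point, mirroring the invocation in run_eval_harness_main above. The model, tokenizer, and axis_mapping arguments are assumed to come from an existing Levanter setup and are not defined here; the literal values are copied from config/harness/eval_marin_dclm_ckpt.yaml added in this patch.

import haliax as hax

from levanter.eval_harness import run_lm_eval_harness


def eval_hellaswag(model, tokenizer, axis_mapping) -> dict:
    # "batch" matches trainer.batch_axis and 512 matches trainer.train_batch_size in the YAML
    EvalBatch = hax.Axis("batch", 512)
    return run_lm_eval_harness(
        model,
        task_spec=["hellaswag"],   # eval_harness.task_spec in the YAML
        tokenizer=tokenizer,
        EvalBatch=EvalBatch,
        axis_resources=axis_mapping,
        max_examples=8,            # eval_harness.max_examples in the YAML
        max_eval_length=128,       # eval_harness.max_eval_length in the YAML
    )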
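
For intuition, a plain-NumPy illustration of the two quantities _eval_loglikelihood reports per example, ignoring named axes and sharding: the summed log-probability of the completion tokens (the negation of the masked, sum-reduced next_token_loss) and whether greedy decoding reproduces every completion token. logits, tokens, and loss_mask are placeholder arrays, and loss_mask is assumed to be 1 exactly at positions whose next token belongs to the completion (and 0 at the final position), as LmExample.from_prompt_and_completion sets it up. This mirrors the semantics, not the sharded implementation.

import numpy as np


def loglikelihood_and_greedy_match(logits: np.ndarray, tokens: np.ndarray, loss_mask: np.ndarray):
    # logits: [seq_len, vocab]; tokens, loss_mask: [seq_len]
    logprobs = logits - np.logaddexp.reduce(logits, axis=-1, keepdims=True)  # log-softmax
    targets = np.roll(tokens, -1)                         # next-token targets (wrap-around is masked out)
    token_ll = logprobs[np.arange(len(tokens)), targets]  # log P(target | prefix) at each position
    total_ll = float(np.sum(token_ll * loss_mask))        # summed over completion positions
    # "correct" means argmax(logits) matches the target everywhere the mask is on
    matches = (np.argmax(logits, axis=-1) == targets) | (loss_mask == 0)
    return total_ll, bool(np.all(matches))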