eval_harness is about there
dlwh committed Nov 24, 2024
1 parent 4ecc630 commit bcfc225
Showing 2 changed files with 157 additions and 71 deletions.
27 changes: 27 additions & 0 deletions config/harness/eval_marin_dclm_ckpt.yaml
@@ -0,0 +1,27 @@
eval_harness:
  task_spec: ["hellaswag"]
  # max_examples: 9984 # this is the max that ends up being divisible by 512 after expansion
  max_examples: 8 # this is the max that ends up being divisible by 512 after expansion
  max_eval_length: 128
#tokenizer: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930
#tokenizer: gs://levanter-checkpoints/marin/olmoish7b_v4_1024_0627/dlwh_7b0627/hf/step-715001/
#tokenizer: gs://levanter-checkpoints/marin/olmoish7b_v4_1024_0627/dlwh_7b0627/step-510000/
#tokenizer: "EleutherAI/gpt-neox-20b"
tokenizer: meta-llama/Meta-Llama-3-8B
model:
  type: llama
#checkpoint_path: gs://marin-us-central2/checkpoints/dclm_baseline_1b_1x_replication_nov12_3404462497seed-b68241/hf/step-54930
checkpoint_path: meta-llama/Meta-Llama-3-8B
checkpoint_is_hf: true
trainer:
  mp: f32
  profiler: true

  per_device_parallelism: -1
  train_batch_size: 512

  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  ray:
    auto_start_cluster: false
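
For orientation (not part of the commit): the eval_harness block above maps onto the LmEvalHarnessConfig fields introduced in this change, and the whole file is consumed by run_eval_harness_main via levanter.config.main (typically by pointing the module at the YAML with levanter's usual --config_path flag, though that exact invocation is an assumption here). A minimal sketch of the harness-specific piece, constructed programmatically rather than from YAML:

# Illustrative sketch only; the field names come from the LmEvalHarnessConfig added in this commit.
from levanter.eval_harness import LmEvalHarnessConfig

harness_cfg = LmEvalHarnessConfig(
    task_spec=["hellaswag"],  # forwarded to lm_eval's tasks.get_task_dict
    max_examples=8,           # passed as limit= to evaluator.evaluate
    max_eval_length=128,      # used to resize model.Pos for evaluation
)
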
201 changes: 130 additions & 71 deletions src/levanter/eval_harness.py
@@ -3,21 +3,23 @@
# https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py
# https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/TPU_cluster.py#L6
import dataclasses
import functools
import json
import logging
import typing
import warnings
from dataclasses import dataclass
from functools import cached_property
from typing import List, Optional, Tuple
from typing import List, Optional, Sequence, Tuple

import equinox as eqx
import jax
import jax.numpy as jnp
import transformers

from levanter.compat.hf_checkpoints import HFCheckpointConverter
import haliax

from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer
from levanter.models.gpt2 import Gpt2Config
from levanter.models.loss import next_token_loss


try:
@@ -33,15 +35,14 @@
from tqdm import tqdm

import haliax as hax
from haliax.nn import cross_entropy_loss
from haliax.partitioning import round_axis_for_partitioning

import levanter.config
from levanter.checkpoint import load_checkpoint
from levanter.data import batched
from levanter.data import AsyncDataset, DataLoader
from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel
from levanter.trainer import TrainerConfig
from levanter.utils.jax_utils import stack_tree, use_cpu_device
from levanter.utils.jax_utils import use_cpu_device
from levanter.utils.tree_utils import inference_mode


@@ -76,103 +77,151 @@ class _RequestType:
    FINISHED = 3


@functools.partial(jax.jit, static_argnums=(0, 3))
def _jit_create_example(Pos, tokens, prompt_len, pad_token_id):
    tokens = hax.named(tokens, Pos)
    return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id)


class EvalDataset(AsyncDataset[LmExample]):
    def __init__(self, Pos, tokenizer, examples: list[Instance]):
        super().__init__()
        self.examples = examples
        self.Pos = Pos
        self.tokenizer = tokenizer

    async def async_len(self) -> int:
        return len(self.examples)

    async def final_length_is_known(self) -> bool:
        return True

    def is_finite(self) -> bool:
        return True

    async def current_len(self) -> Optional[int]:
        return len(self.examples)

    async def get_batch(self, indices: Sequence[int]) -> List[LmExample]:
        out = []
        pad_token_id = self.tokenizer.pad_token_id

        reqs = [(self.examples[i].args[0], self.examples[i].args[1]) for i in indices]

        for context, completion in reqs:
            whole_enc = self.tokenizer(context + completion)
            context_enc = self.tokenizer(context)

            context_enc_len = len(context_enc["input_ids"])

            tokens, length = self._truncate_or_pad(whole_enc, context_enc_len)
            example = _jit_create_example(self.Pos, tokens, length, pad_token_id)

            out.append(example)

        return out

    def _truncate_or_pad(self, encoded, prompt_length):
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        ex_pad = self.tokenizer.pad(
            encoded,
            padding="max_length",
            max_length=self.Pos.size,
            return_tensors="np",
        )

        truncated = ex_pad["input_ids"][-self.Pos.size :]
        # if we truncated the prompt, we need to adjust the prompt length
        if len(truncated) < len(encoded):
            prompt_length -= len(encoded) - len(truncated)
            if prompt_length < 0:
                prompt_length = 0
                logger.warning("Prompt length is negative after truncation. Setting to 0.")

        return truncated, prompt_length


class LevanterHarnessLM(LM):
    def __init__(self, EvalBatch: hax.Axis, model: LmHeadModel, axis_resources, tokenizer):
    def __init__(self, EvalBatch: hax.Axis, EvalPos: hax.Axis, model: LmHeadModel, axis_resources, tokenizer):
        super().__init__()
        self.EvalBatch = EvalBatch
        self.EvalPos = EvalPos
        self.model = model
        self.axis_resources = axis_resources
        self.tokenizer = tokenizer

        def _eval_loglikelihood(model: LmHeadModel, example: LmExample):
            logits = model(example.tokens)
            logits = model(example.tokens, attn_mask=example.attn_mask)
            logits = logits.astype(jnp.float32)
            Pos = logits.resolve_axis(self.EvalPos.name)

            loss = next_token_loss(
                Pos=Pos,
                Vocab=model.Vocab,
                logits=logits,
                true_ids=example.tokens,
                loss_mask=example.loss_mask,
                reduction=hax.sum,
                reduction_axis=Pos,
            )

            targets = hax.roll(example.tokens, -1, axis=model.Pos.name)
            target_y = hax.nn.one_hot(targets, model.Vocab, dtype=logits.dtype)
            loss = cross_entropy_loss(logits, model.Vocab, target_y, where=example.loss_mask, reduction_axis=model.Pos)
            # to tell if we got the right answer, we want to check that argmax(logits) == tokens wherever loss_mask is 1
            not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool)
            pred_targets = hax.argmax(logits, axis=model.Vocab)
            correct = hax.all(hax.equal(pred_targets, targets) | hax.logical_not(example.loss_mask), axis=model.Pos)
            targets = hax.roll(example.tokens, -1, axis=Pos)
            freebie = hax.logical_not(example.loss_mask * not_last_loss_mask)
            correct = hax.all(hax.equal(pred_targets, targets) + freebie, axis=Pos)

            return loss, correct
            return -loss, correct

        # no sharded outputs
        self._jit_loglikelihood = hax.named_jit(
            _eval_loglikelihood, axis_resources=axis_resources, out_axis_resources={}
        )

    def _stack_batch(self, examples):
        return stack_tree(self.EvalBatch, examples, pad_to_batch_size=True)

    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        """
        Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.
        Args:
            requests:
        Returns:
        """
        dataset = EvalDataset(self.EvalPos, self.tokenizer, requests)

        contexts = self.tokenizer([req.args[0] for req in requests])["input_ids"]
        completions = self.tokenizer([req.args[1] for req in requests])["input_ids"]

        examples: list[LmExample] = []

        @hax.named_jit
        def _jit_create_example(tokens, prompt_len):
            tokens = hax.named(tokens, self.model.Pos)
            return LmExample.from_prompt_and_completion(
                self.model.Pos, tokens, prompt_len, ignore_id=self.tokenizer.pad_token_id
            )
        mesh = haliax.partitioning._get_mesh()

        # TODO: offload this to an evalbatchloader
        for context, completion in zip(tqdm(contexts, desc="Creating examples"), completions):
            tokens, length = self._truncate_or_pad(context, completion)
            tokens = jnp.array(tokens)
            length = jnp.array(length)
            example = _jit_create_example(tokens, length)
            examples.append(example)
        loader = DataLoader(
            self.EvalBatch, dataset, max_buffered_batches=1024, mesh=mesh, axis_resources=self.axis_resources
        )

        result: list[tuple[float, bool]] = []
        for batch in batched(tqdm(examples, desc="examples", leave=False), self.EvalBatch.size):
            logger.info("Processing batch")
            batch_example = self._stack_batch(batch)
            # batch_example = jax.device_put(batch_example, jax.local_devices()[0])
            out_lls, out_correct = self._jit_loglikelihood(self.model, batch_example)
        for batch in tqdm(loader, desc="Loglikelihood", unit="ba"):
            out_lls, out_correct = self._jit_loglikelihood(self.model, batch)
            result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array))

        # skip padding
        result = result[: len(examples)]
        result = result[: len(requests)]

        return result

    def _truncate_or_pad(self, context, completion):
        max_len = self.model.Pos.size
        if len(completion) > max_len:
            warnings.warn(f"Completion is longer than max length 24,534. Truncating.")
            completion = completion[:max_len]
        pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id

        if len(context) + len(completion) > max_len:
            context = context[-(max_len - len(completion)) :]
        else:
            # right pad with padding token
            context = context + [pad_token_id] * (max_len - len(context) - len(completion))

        return jnp.array(context + completion), len(context)

    def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
        raise NotImplementedError()

    def generate_until(self, requests) -> List[str]:
        raise NotImplementedError()


def run_lm_eval_harness(model, task_spec: list[str], tokenizer, EvalBatch, axis_resources, max_examples=None) -> dict:
    harness = LevanterHarnessLM(EvalBatch, model, axis_resources, tokenizer)
def run_lm_eval_harness(
    model,
    task_spec: list[str],
    tokenizer,
    EvalBatch,
    axis_resources,
    max_examples: int | None = None,
    max_eval_length: int | None = None,
) -> dict:
    EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length)
    harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer)
    tasks_to_run = tasks.get_task_dict(task_spec)
    outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples)

@@ -181,13 +230,14 @@ def run_lm_eval_harness(model, task_spec: list[str], tokenizer, EvalBatch, axis_

@dataclass(frozen=True)
class LmEvalHarnessConfig:
    task_spec: Optional[list[str]] = None
    max_examples: Optional[int] = None
    task_spec: list[str] | None = None
    max_examples: int | None = None
    max_eval_length: int | None = None

    def task_spec_or_default(self):
        return self.task_spec or [
            # "lambada",
            # "piqa",
            "piqa",
            "hellaswag",
            # "winogrande",
            # "mathqa",
@@ -218,7 +268,7 @@ def EvalBatch(self):

    @cached_property
    def the_tokenizer(self):
        return transformers.AutoTokenizer.from_pretrained(self.tokenizer)
        return load_tokenizer(self.tokenizer)


def run_eval_harness_main(config: EvalHarnessConfig):
@@ -244,10 +294,10 @@ def run_eval_harness_main(config: EvalHarnessConfig):
    # initialize the model
    if config.checkpoint_is_hf:
        model_config = config.model
        converter: HFCheckpointConverter = model_config.default_hf_checkpoint_converter # type: ignore
        converter: HFCheckpointConverter = model_config.hf_checkpoint_converter()
        converter = converter.replaced(reference_checkpoint=config.checkpoint_path, tokenizer=tokenizer)
        model = converter.load_pretrained(
            model_config.model_type, model_config, ref=config.checkpoint_path, dtype=config.trainer.mp.compute_dtype # type: ignore
            model_config.model_type, ref=config.checkpoint_path, dtype=config.trainer.mp.compute_dtype # type: ignore
        )
    else:
        with use_cpu_device():
@@ -265,14 +315,23 @@
        config.EvalBatch,
        axis_resources=compute_axis_mapping,
        max_examples=max_examples,
        max_eval_length=config.eval_harness.max_eval_length,
    )

    logger.info("Finished running LM eval harness")
    # log the results as json
    with open("lm_eval_results.json", "w") as f:

        json.dump(outputs, f, indent=2)

    # also write to stdout
    if jax.process_index() == 0:
        print(json.dumps(outputs, indent=2), flush=True)

    # also log the results
    levanter.tracker.current_tracker().log_artifact("lm_eval_results.json")

    return outputs


if __name__ == "__main__":
    levanter.config.main(run_eval_harness_main)()
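
As a reading aid (not part of the commit), here is a standalone plain-JAX sketch of what the reworked _eval_loglikelihood computes for a single example: the summed log-probability of the completion tokens (the new code returns -loss because the harness expects log-likelihoods, where higher is better) and a boolean saying whether greedy argmax reproduces every completion token. All names below are illustrative, and the loss mask is assumed to already exclude the final position (which the commit handles explicitly via not_last_loss_mask).

import jax
import jax.numpy as jnp

def score_completion(logits: jnp.ndarray, tokens: jnp.ndarray, loss_mask: jnp.ndarray):
    """logits: [seq, vocab]; tokens: [seq]; loss_mask: 1.0 at positions whose *next* token belongs to the completion."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    targets = jnp.roll(tokens, -1)  # next-token targets, as in hax.roll(example.tokens, -1, ...)
    target_lp = jnp.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]
    loglikelihood = jnp.sum(target_lp * loss_mask)  # summed (not averaged) over completion positions
    preds = jnp.argmax(logits, axis=-1)
    # masked-out positions are "freebies"; everywhere else the greedy prediction must match the target
    correct = jnp.all((preds == targets) | (loss_mask == 0))
    return loglikelihood, correct

The haliax version in the diff does the same computation with named axes, wrapped in hax.named_jit so it runs sharded over the eval batch.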
