
refactor
dlwh committed Nov 26, 2024
1 parent bcfc225 commit e0ef6f8
Showing 3 changed files with 67 additions and 84 deletions.
44 changes: 0 additions & 44 deletions src/levanter/callbacks.py
@@ -19,14 +19,12 @@

import levanter.tracker
from levanter.data import AsyncDataset, DataLoader
from levanter.eval_harness import LmEvalHarnessConfig
from levanter.tracker.helpers import log_optimizer_hyperparams
from levanter.tracker.wandb import WandbConfig
from levanter.trainer import StepInfo
from levanter.utils import flop_utils
from levanter.utils.jax_utils import barrier_sync, jnp_to_python
from levanter.utils.logging import save_xla_dumps_to_wandb
from levanter.utils.tree_utils import inference_mode
from levanter.visualization import compute_and_visualize_log_probs as viz_probs


@@ -425,45 +423,3 @@ def _tqdm_logging_one_time_setup():
return
_did_tqdm_logging_one_time_setup = True
tqdm_logging.tqdm_logging.set_log_rate(timedelta(seconds=60))


def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources):
from levanter.eval_harness import run_lm_eval_harness

def lm_eval_harness(step: StepInfo, force=False):
if step.step == 0 and not force:
return # don't run eval on the first step

model = inference_mode(step.model, True)
outputs = run_lm_eval_harness(
model,
config.task_spec_or_default(),
tokenizer,
EvalBatch,
axis_resources,
max_examples=config.max_examples,
)

if jax.process_index() == 0:
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
import json

json.dump(outputs, f)
levanter.tracker.current_tracker().log_artifact(
f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output"
)

# also log accuracy statistics etc
metrics_to_log = {}
for task, metrics in outputs["results"].items():
for metric, value in metrics.items():
if metric.endswith(",none"):
metric = metric[: -len(",none")]

if metric != "alias":
# levanter.tracker.log_metrics({f"lm_eval/{task}/{metric}": value}, step=step.step)
metrics_to_log[f"lm_eval/{task}/{metric}"] = value

levanter.tracker.log_metrics(metrics_to_log, step=step.step)

return lm_eval_harness
103 changes: 64 additions & 39 deletions src/levanter/eval_harness.py
@@ -6,6 +6,7 @@
import functools
import json
import logging
import tempfile
import typing
from dataclasses import dataclass
from functools import cached_property
@@ -17,6 +18,7 @@

import haliax

import levanter.tracker
from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer
from levanter.models.gpt2 import Gpt2Config
from levanter.models.loss import next_token_loss
@@ -32,7 +34,7 @@
evaluator = object
# tasks = object

from tqdm import tqdm
from tqdm_loggable.auto import tqdm

import haliax as hax
from haliax.partitioning import round_axis_for_partitioning
@@ -41,48 +43,14 @@
from levanter.checkpoint import load_checkpoint
from levanter.data import AsyncDataset, DataLoader
from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel
from levanter.trainer import TrainerConfig
from levanter.trainer import StepInfo, TrainerConfig
from levanter.utils.jax_utils import use_cpu_device
from levanter.utils.tree_utils import inference_mode


logger = logging.getLogger(__name__)


# Ok this is a bit complicated to do because it's distributed systems and that's always hard.
# The idea is that we want to pass an LM adaptor to the harness, and then the harness will call the LM adaptor
# with a request, which we'll format, shard, and send to the model. The model will then return the result to the harness
# which will then return the result to the user.

# As we so often do, we will coordinate execution through JAX itself.

# Process 0 will:
# - Pass an adaptor to the eval harness
# - The eval harness will call the adaptor with a request
# - When a request comes in, it will call broadcast_one_to_all with a (REQUEST_TYPE, request) to send the request
# - It then invokes the model with the request and returns the result to the eval harness
# - When finished, it will call broadcast_one_to_all with a (FINISHED_TYPE, result) to send the result

# Process 1..n will:
# - Wait for a (REQUEST_TYPE, request) broadcast
# - if FINISHED_TYPE, break
# - Invoke the model with the request
# - loop


class _RequestType:
LOG_LIKELIHOOD = 0
GENERATE_UNTIL = 1
LOG_LIKELIHOOD_ROLLING = 2
FINISHED = 3
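
The comment block and _RequestType constants above describe the broadcast-based coordination scheme this file relied on. As a rough sketch of that scheme (not code from this commit), the leader/follower loops below use jax.experimental.multihost_utils.broadcast_one_to_all; the function names, the run_model callable, and the fixed-shape int32 token payload are illustrative assumptions.

# Hypothetical sketch of the coordination described above; not part of this commit.
# Assumes every request is pre-tokenized into an int32 array of a fixed shape so
# that all hosts can call broadcast_one_to_all with identically shaped pytrees.
import numpy as np
from jax.experimental import multihost_utils


def leader_loop(run_model, token_batches, batch_shape):
    # Process 0: announce each request to every host, then run it locally too.
    results = []
    for tokens in token_batches:
        multihost_utils.broadcast_one_to_all((_RequestType.LOG_LIKELIHOOD, tokens))
        results.append(run_model(tokens))
    # Signal completion; the payload only needs to match the broadcast structure.
    multihost_utils.broadcast_one_to_all(
        (_RequestType.FINISHED, np.zeros(batch_shape, dtype=np.int32))
    )
    return results


def follower_loop(run_model, batch_shape):
    # Processes 1..n: replay whatever process 0 broadcasts so that every host
    # participates in the same sharded model call, then loop until FINISHED.
    while True:
        kind, tokens = multihost_utils.broadcast_one_to_all(
            (_RequestType.FINISHED, np.zeros(batch_shape, dtype=np.int32))
        )
        if int(kind) == _RequestType.FINISHED:
            break
        run_model(tokens)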


@functools.partial(jax.jit, static_argnums=(0, 3))
def _jit_create_example(Pos, tokens, prompt_len, pad_token_id):
tokens = hax.named(tokens, Pos)
return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id)


class EvalDataset(AsyncDataset[LmExample]):
def __init__(self, Pos, tokenizer, examples: list[Instance]):
super().__init__()
@@ -211,6 +179,12 @@ def generate_until(self, requests) -> List[str]:
raise NotImplementedError()


@functools.partial(jax.jit, static_argnums=(0, 3))
def _jit_create_example(Pos, tokens, prompt_len, pad_token_id):
tokens = hax.named(tokens, Pos)
return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id)


def run_lm_eval_harness(
model,
task_spec: list[str],
@@ -219,11 +193,12 @@ def run_lm_eval_harness(
axis_resources,
max_examples: int | None = None,
max_eval_length: int | None = None,
log_samples: bool = False,
) -> dict:
EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length)
harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer)
tasks_to_run = tasks.get_task_dict(task_spec)
outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples)
outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=log_samples)

return outputs

@@ -233,6 +208,7 @@ class LmEvalHarnessConfig:
task_spec: list[str] | None = None
max_examples: int | None = None
max_eval_length: int | None = None
log_samples: bool = False

def task_spec_or_default(self):
return self.task_spec or [
@@ -242,9 +218,9 @@ def task_spec_or_default(self):
# "winogrande",
# "mathqa",
# "pubmedqa",
# "boolq",
"boolq",
# "cb",
# "copa",
"copa",
# "multirc",
# "record",
# "wic",
@@ -316,6 +292,7 @@ def run_eval_harness_main(config: EvalHarnessConfig):
axis_resources=compute_axis_mapping,
max_examples=max_examples,
max_eval_length=config.eval_harness.max_eval_length,
log_samples=config.eval_harness.log_samples,
)

logger.info("Finished running LM eval harness")
@@ -329,9 +306,57 @@

# also log the results
levanter.tracker.current_tracker().log_artifact("lm_eval_results.json")
log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())

return outputs


def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter.tracker.Tracker] = None):
if tracker is None:
tracker = levanter.tracker.current_tracker()

to_log = {}
for task_name, task_results in report["results"].items():
for metric_name, metric_value in task_results.items():
if metric_name.endswith(",none"):
metric_name = metric_name[:-5]

if isinstance(metric_value, float | int):
to_log[f"{prefix}/{task_name}/{metric_name}"] = metric_value

tracker.log(to_log, step=None)
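
For reference, here is a hypothetical harness report and what the log_report_to_tracker helper above would log for it; the nested "results" layout and the ",none" filter suffix follow the lm-evaluation-harness output format, and the numbers are invented for illustration.

# Invented example input for log_report_to_tracker; values are for illustration only.
report = {
    "results": {
        "boolq": {"alias": "boolq", "acc,none": 0.71, "acc_stderr,none": 0.008},
        "copa": {"alias": "copa", "acc,none": 0.66, "acc_stderr,none": 0.047},
    }
}

# log_report_to_tracker("lm_eval", report) strips the ",none" suffix, keeps only
# numeric values (the string-valued "alias" entries are skipped), and logs:
#   {"lm_eval/boolq/acc": 0.71, "lm_eval/boolq/acc_stderr": 0.008,
#    "lm_eval/copa/acc": 0.66, "lm_eval/copa/acc_stderr": 0.047}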


def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources):
def lm_eval_harness(step: StepInfo, force=False):
if step.step == 0 and not force:
return # don't run eval on the first step

model = inference_mode(step.model, True)
outputs = run_lm_eval_harness(
model,
config.task_spec_or_default(),
tokenizer,
EvalBatch,
axis_resources,
max_examples=config.max_examples,
max_eval_length=config.max_eval_length,
log_samples=config.log_samples,
)

if jax.process_index() == 0:
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
import json

json.dump(outputs, f)
levanter.tracker.current_tracker().log_artifact(
f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output"
)

log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())

return lm_eval_harness


if __name__ == "__main__":
levanter.config.main(run_eval_harness_main)()
4 changes: 3 additions & 1 deletion src/levanter/main/train_lm.py
@@ -13,6 +13,8 @@
from haliax.partitioning import named_jit, round_axis_for_partitioning

import levanter
import levanter.eval
import levanter.eval_harness
from levanter import callbacks
from levanter.checkpoint import EpochCheckpointer, load_checkpoint
from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback
@@ -253,7 +255,7 @@ def main(config: TrainLmConfig):
if config.eval_harness is not None:
eval_harness = config.eval_harness
trainer.add_hook(
callbacks.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping),
levanter.eval_harness.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping),
every=config.eval_harness_steps,
)

