diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index cb8c016c0..983750685 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -19,14 +19,12 @@ import levanter.tracker from levanter.data import AsyncDataset, DataLoader -from levanter.eval_harness import LmEvalHarnessConfig from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig from levanter.trainer import StepInfo from levanter.utils import flop_utils from levanter.utils.jax_utils import barrier_sync, jnp_to_python from levanter.utils.logging import save_xla_dumps_to_wandb -from levanter.utils.tree_utils import inference_mode from levanter.visualization import compute_and_visualize_log_probs as viz_probs @@ -425,45 +423,3 @@ def _tqdm_logging_one_time_setup(): return _did_tqdm_logging_one_time_setup = True tqdm_logging.tqdm_logging.set_log_rate(timedelta(seconds=60)) - - -def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources): - from levanter.eval_harness import run_lm_eval_harness - - def lm_eval_harness(step: StepInfo, force=False): - if step.step == 0 and not force: - return # don't run eval on the first step - - model = inference_mode(step.model, True) - outputs = run_lm_eval_harness( - model, - config.task_spec_or_default(), - tokenizer, - EvalBatch, - axis_resources, - max_examples=config.max_examples, - ) - - if jax.process_index() == 0: - with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f: - import json - - json.dump(outputs, f) - levanter.tracker.current_tracker().log_artifact( - f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output" - ) - - # also log accuracy statistics etc - metrics_to_log = {} - for task, metrics in outputs["results"].items(): - for metric, value in metrics.items(): - if metric.endswith(",none"): - metric = metric[: -len(",none")] - - if metric != "alias": - # levanter.tracker.log_metrics({f"lm_eval/{task}/{metric}": value}, 
step=step.step) - metrics_to_log[f"lm_eval/{task}/{metric}"] = value - - levanter.tracker.log_metrics(metrics_to_log, step=step.step) - - return lm_eval_harness diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index bde8f688b..e6b2eb0bc 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -6,6 +6,7 @@ import functools import json import logging +import tempfile import typing from dataclasses import dataclass from functools import cached_property @@ -17,6 +18,7 @@ import haliax +import levanter.tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer from levanter.models.gpt2 import Gpt2Config from levanter.models.loss import next_token_loss @@ -32,7 +34,7 @@ evaluator = object # tasks = object -from tqdm import tqdm +from tqdm_loggable.auto import tqdm import haliax as hax from haliax.partitioning import round_axis_for_partitioning @@ -41,7 +43,7 @@ from levanter.checkpoint import load_checkpoint from levanter.data import AsyncDataset, DataLoader from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel -from levanter.trainer import TrainerConfig +from levanter.trainer import StepInfo, TrainerConfig from levanter.utils.jax_utils import use_cpu_device from levanter.utils.tree_utils import inference_mode @@ -49,40 +51,6 @@ logger = logging.getLogger(__name__) -# Ok this is a bit complicated to do because it's distributed systems and that's always hard. -# The idea is that we want to pass an LM adaptor to the harness, and then the harness will call the LM adaptor -# with a request, which we'll format, shard, and send to the model. The model will then return the result to the harness -# which will then return the result to the user. - -# As we so often do, we will coordinate execution through JAX itself. 
@functools.partial(jax.jit, static_argnums=(0, 3))
def _jit_create_example(Pos, tokens, prompt_len, pad_token_id):
    """Build an ``LmExample`` from a prompt-plus-completion token array under ``jax.jit``.

    ``Pos`` and ``pad_token_id`` (positional arguments 0 and 3) are marked static
    for the jit via ``static_argnums``; ``tokens`` and ``prompt_len`` are traced.

    Args:
        Pos: axis used to name the token array.
        tokens: raw token array, wrapped as a named array over ``Pos``.
        prompt_len: prompt length forwarded to ``LmExample.from_prompt_and_completion``.
        pad_token_id: forwarded as ``ignore_id``.

    Returns:
        Whatever ``LmExample.from_prompt_and_completion`` returns for these inputs.
    """
    named_tokens = hax.named(tokens, Pos)
    example = LmExample.from_prompt_and_completion(
        Pos,
        named_tokens,
        prompt_len,
        ignore_id=pad_token_id,
    )
    return example
def log_report_to_tracker(
    prefix: str,
    report: dict,
    tracker: "Optional[levanter.tracker.Tracker]" = None,
    *,
    step: "Optional[int]" = None,
):
    """Log the numeric metrics of an eval-harness report to a tracker.

    Walks ``report["results"]`` (task name -> {metric name -> value}) and logs
    every ``int``/``float`` value under the key ``{prefix}/{task}/{metric}``.
    A trailing ``",none"`` on a metric name is stripped before the key is built.
    Non-numeric values are skipped.

    Args:
        prefix: leading component of every logged key.
        report: dict with a ``"results"`` mapping as described above.
        tracker: object whose ``log(metrics, step=...)`` is called; when ``None``
            the result of ``levanter.tracker.current_tracker()`` is used.
        step: step forwarded to ``tracker.log``. Defaults to ``None``, which
            preserves the previous behavior of this function.

    Raises:
        KeyError: if ``report`` has no ``"results"`` entry.
    """
    if tracker is None:
        tracker = levanter.tracker.current_tracker()

    suffix = ",none"
    to_log = {}
    for task_name, task_results in report["results"].items():
        for metric_name, metric_value in task_results.items():
            # Only scalar numbers are logged; anything else is dropped here.
            if not isinstance(metric_value, (float, int)):
                continue
            # str has no `ends_with`; removesuffix is a no-op when the suffix is absent.
            clean_name = metric_name.removesuffix(suffix)
            to_log[f"{prefix}/{task_name}/{clean_name}"] = metric_value

    tracker.log(to_log, step=step)


def lm_eval_harness(config: "LmEvalHarnessConfig", tokenizer, EvalBatch, axis_resources):
    """Build a hook that runs the LM eval harness on the model of a training step.

    The returned callable takes ``(step, force=False)``. It does nothing at step 0
    unless ``force`` is true. Otherwise it switches ``step.model`` to inference
    mode, runs ``run_lm_eval_harness`` with the task spec, ``max_examples``,
    ``max_eval_length`` and ``log_samples`` taken from ``config``, and, on JAX
    process 0, writes the outputs to a temporary JSON file, logs that file as an
    artifact named ``lm_eval_output.{step}``, and logs the numeric metrics under
    the ``lm_eval`` prefix at the current step.

    Args:
        config: eval-harness settings read as described above.
        tokenizer: forwarded to ``run_lm_eval_harness``.
        EvalBatch: forwarded to ``run_lm_eval_harness``.
        axis_resources: forwarded to ``run_lm_eval_harness``.

    Returns:
        The hook callable.
    """

    def lm_eval_harness(step: StepInfo, force=False):
        if step.step == 0 and not force:
            return  # don't run eval on the first step

        model = inference_mode(step.model, True)
        outputs = run_lm_eval_harness(
            model,
            config.task_spec_or_default(),
            tokenizer,
            EvalBatch,
            axis_resources,
            max_examples=config.max_examples,
            max_eval_length=config.max_eval_length,
            log_samples=config.log_samples,
        )

        # NOTE(review): block structure below was reconstructed from a
        # whitespace-collapsed diff; confirm that metric logging is meant to
        # happen on process 0 only.
        if jax.process_index() == 0:
            tracker = levanter.tracker.current_tracker()

            # NOTE(review): delete=False leaves the file on disk after the
            # block exits; presumably so the tracker can still read it later.
            with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
                json.dump(outputs, f)
                # Push buffered JSON to disk before handing the path to the
                # tracker; otherwise the artifact can be truncated.
                f.flush()
                tracker.log_artifact(f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output")

            # Attach metrics to the training step that produced them.
            log_report_to_tracker("lm_eval", outputs, tracker, step=step.step)

    return lm_eval_harness