diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py
index 834e06bd7..b9f4381fe 100644
--- a/src/levanter/eval_harness.py
+++ b/src/levanter/eval_harness.py
@@ -38,6 +38,7 @@
 from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer
 from levanter.models.gpt2 import Gpt2Config
 from levanter.models.loss import next_token_loss
+from levanter.utils.hf_utils import HfTokenizer
 
 
 try:
@@ -52,7 +53,7 @@
 from tqdm_loggable.auto import tqdm
 
 import haliax as hax
-from haliax.partitioning import round_axis_for_partitioning
+from haliax.partitioning import ResourceMapping, round_axis_for_partitioning
 
 import levanter.config
 from levanter.checkpoint import load_checkpoint
@@ -361,23 +362,20 @@ def run_lm_eval_harness(
 
 
 def _actually_run_eval_harness(
-    config: LmEvalHarnessConfig, model: LM, tasks_to_run: dict, tokenizer, EvalBatch, axis_resources
+    config: LmEvalHarnessConfig,
+    model: LmHeadModel,
+    tasks_to_run: dict,
+    tokenizer: HfTokenizer,
+    EvalBatch: haliax.Axis,
+    axis_resources: ResourceMapping,
 ):
     """
     Actually run the LM Eval Harness on the given model and tasks. This is a separate function so that it can be used
     by the main function and the callback function.
 
-    Args:
-        config:
-        model:
-        tasks_to_run:
-        tokenizer: The tokenizer to use
-        EvalBatch: The batch axis for compute
-        axis_resources: axis mapping for compute
-
     Returns:
-
-
+        The outputs of the LM Eval Harness with the following extra keys:
+        - "averages": A dictionary with macro and micro averages for all metrics.
     """
 
     max_examples = config.max_examples
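
The diff above tightens the signature of _actually_run_eval_harness and documents the extra "averages" key in its output. The following sketch, which is not part of the diff, illustrates how a caller might use the newly annotated signature and read the documented "averages" key; the helper name report_eval, the batch size, and the exact shape of the averages dictionary are assumptions for illustration.

import haliax

from levanter.eval_harness import LmEvalHarnessConfig, _actually_run_eval_harness


def report_eval(config: LmEvalHarnessConfig, model, tasks_to_run: dict, tokenizer, axis_resources) -> None:
    # EvalBatch is the batch axis for compute; the size 32 is illustrative.
    EvalBatch = haliax.Axis("batch", 32)
    outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources)
    # Per the updated docstring, outputs carries an "averages" entry with
    # macro and micro averages for all metrics.
    for name, value in outputs["averages"].items():
        print(f"{name}: {value:.4f}")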