diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index c3c33ac74..4e392cf4d 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -688,8 +688,11 @@ def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_reso def lm_eval_harness(step: StepInfo, force=False): if step.step == 0 and not force: return - - model = inference_mode(step.model, True) + + if step.use_ema: + model = inference_mode(step.ema_model, True) + else: + model = inference_mode(step.model, True) logger.info("Running eval harness...") outputs = _actually_run_eval_harness( config,