diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py
index cb8c016c0..983750685 100644
--- a/src/levanter/callbacks.py
+++ b/src/levanter/callbacks.py
@@ -19,14 +19,12 @@
 
 import levanter.tracker
 from levanter.data import AsyncDataset, DataLoader
-from levanter.eval_harness import LmEvalHarnessConfig
 from levanter.tracker.helpers import log_optimizer_hyperparams
 from levanter.tracker.wandb import WandbConfig
 from levanter.trainer import StepInfo
 from levanter.utils import flop_utils
 from levanter.utils.jax_utils import barrier_sync, jnp_to_python
 from levanter.utils.logging import save_xla_dumps_to_wandb
-from levanter.utils.tree_utils import inference_mode
 from levanter.visualization import compute_and_visualize_log_probs as viz_probs
 
 
@@ -425,45 +423,3 @@ def _tqdm_logging_one_time_setup():
         return
     _did_tqdm_logging_one_time_setup = True
     tqdm_logging.tqdm_logging.set_log_rate(timedelta(seconds=60))
-
-
-def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources):
-    from levanter.eval_harness import run_lm_eval_harness
-
-    def lm_eval_harness(step: StepInfo, force=False):
-        if step.step == 0 and not force:
-            return  # don't run eval on the first step
-
-        model = inference_mode(step.model, True)
-        outputs = run_lm_eval_harness(
-            model,
-            config.task_spec_or_default(),
-            tokenizer,
-            EvalBatch,
-            axis_resources,
-            max_examples=config.max_examples,
-        )
-
-        if jax.process_index() == 0:
-            with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
-                import json
-
-                json.dump(outputs, f)
-                levanter.tracker.current_tracker().log_artifact(
-                    f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output"
-                )
-
-            # also log accuracy statistics etc
-            metrics_to_log = {}
-            for task, metrics in outputs["results"].items():
-                for metric, value in metrics.items():
-                    if metric.endswith(",none"):
-                        metric = metric[: -len(",none")]
-
-                    if metric != "alias":
-                        # levanter.tracker.log_metrics({f"lm_eval/{task}/{metric}": value}, step=step.step)
-                        metrics_to_log[f"lm_eval/{task}/{metric}"] = value
-
-            levanter.tracker.log_metrics(metrics_to_log, step=step.step)
-
-    return lm_eval_harness
diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py
index bde8f688b..e6b2eb0bc 100644
--- a/src/levanter/eval_harness.py
+++ b/src/levanter/eval_harness.py
@@ -6,6 +6,7 @@
 import functools
 import json
 import logging
+import tempfile
 import typing
 from dataclasses import dataclass
 from functools import cached_property
@@ -17,6 +18,7 @@
 
 import haliax
 
+import levanter.tracker
 from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer
 from levanter.models.gpt2 import Gpt2Config
 from levanter.models.loss import next_token_loss
@@ -32,7 +34,7 @@
     evaluator = object
     # tasks = object
 
-from tqdm import tqdm
+from tqdm_loggable.auto import tqdm
 
 import haliax as hax
 from haliax.partitioning import round_axis_for_partitioning
@@ -41,7 +43,7 @@
 from levanter.checkpoint import load_checkpoint
 from levanter.data import AsyncDataset, DataLoader
 from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel
-from levanter.trainer import TrainerConfig
+from levanter.trainer import StepInfo, TrainerConfig
 from levanter.utils.jax_utils import use_cpu_device
 from levanter.utils.tree_utils import inference_mode
 
@@ -49,40 +51,6 @@
 logger = logging.getLogger(__name__)
 
 
-# Ok this is a bit complicated to do because it's distributed systems and that's always hard.
-# The idea is that we want to pass an LM adaptor to the harness, and then the harness will call the LM adaptor
-# with a request, which we'll format, shard, and send to the model. The model will then return the result to the harness
-# which will then return the result to the user.
-
-# As we so often do, we will coordinate execution through JAX itself.
-
-# Process 0 will:
-# - Pass an adaptor to the eval harness
-# - The eval harness will call the adaptor with a request
-# - When a request comes in, it will call broadcast_one_to_all with a (REQUEST_TYPE, request) to send the request
-# - It then invokes the model with the request and returns the result to the eval harness
-# - When finished, it will call broadcast_one_to_all with a (FINISHED_TYPE, result) to send the result
-
-# Process 1..n will:
-# - Wait for a (REQUEST_TYPE, request) broadcast
-# - if FINISHED_TYPE, break
-# - Invoke the model with the request
-# - loop
-
-
-class _RequestType:
-    LOG_LIKELIHOOD = 0
-    GENERATE_UNTIL = 1
-    LOG_LIKELIHOOD_ROLLING = 2
-    FINISHED = 3
-
-
-@functools.partial(jax.jit, static_argnums=(0, 3))
-def _jit_create_example(Pos, tokens, prompt_len, pad_token_id):
-    tokens = hax.named(tokens, Pos)
-    return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id)
-
-
 class EvalDataset(AsyncDataset[LmExample]):
     def __init__(self, Pos, tokenizer, examples: list[Instance]):
         super().__init__()
@@ -211,6 +179,12 @@ def generate_until(self, requests) -> List[str]:
         raise NotImplementedError()
 
 
+@functools.partial(jax.jit, static_argnums=(0, 3))
+def _jit_create_example(Pos, tokens, prompt_len, pad_token_id):
+    tokens = hax.named(tokens, Pos)
+    return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id)
+
+
 def run_lm_eval_harness(
     model,
     task_spec: list[str],
@@ -219,11 +193,12 @@ def run_lm_eval_harness(
     axis_resources,
     max_examples: int | None = None,
     max_eval_length: int | None = None,
+    log_samples: bool = False,
 ) -> dict:
     EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length)
     harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer)
     tasks_to_run = tasks.get_task_dict(task_spec)
-    outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples)
+    outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=log_samples)
 
     return outputs
 
@@ -233,6 +208,7 @@ class LmEvalHarnessConfig:
     task_spec: list[str] | None = None
     max_examples: int | None = None
     max_eval_length: int | None = None
+    log_samples: bool = False
 
     def task_spec_or_default(self):
         return self.task_spec or [
@@ -242,9 +218,9 @@ def task_spec_or_default(self):
             # "winogrande",
             # "mathqa",
             # "pubmedqa",
-            # "boolq",
+            "boolq",
             # "cb",
-            # "copa",
+            "copa",
             # "multirc",
             # "record",
             # "wic",
@@ -316,6 +292,7 @@ def run_eval_harness_main(config: EvalHarnessConfig):
             axis_resources=compute_axis_mapping,
             max_examples=max_examples,
             max_eval_length=config.eval_harness.max_eval_length,
+            log_samples=config.eval_harness.log_samples,
         )
 
         logger.info("Finished running LM eval harness")
@@ -329,9 +306,60 @@
 
         # also log the results
         levanter.tracker.current_tracker().log_artifact("lm_eval_results.json")
+        log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
 
         return outputs
 
 
+def log_report_to_tracker(prefix: str, report: dict, tracker: levanter.tracker.Tracker | None = None):
+    if tracker is None:
+        tracker = levanter.tracker.current_tracker()
+
+    to_log = {}
+    for task_name, task_results in report["results"].items():
+        for metric_name, metric_value in task_results.items():
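+            # lm-eval reports each metric with its filter name appended (e.g. "acc,none"); strip the default ",none" suffix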
+            if metric_name.endswith(",none"):
+                metric_name = metric_name[: -len(",none")]
+
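+            # log only numeric values; skips non-numeric entries such as the task "alias" string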
+            if isinstance(metric_value, float | int):
+                to_log[f"{prefix}/{task_name}/{metric_name}"] = metric_value
+
+    tracker.log(to_log, step=None)
+
+
+def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources):
+    def lm_eval_harness(step: StepInfo, force=False):
+        if step.step == 0 and not force:
+            return  # don't run eval on the first step
+
+        model = inference_mode(step.model, True)
+        outputs = run_lm_eval_harness(
+            model,
+            config.task_spec_or_default(),
+            tokenizer,
+            EvalBatch,
+            axis_resources,
+            max_examples=config.max_examples,
+            max_eval_length=config.max_eval_length,
+            log_samples=config.log_samples,
+        )
+
+        if jax.process_index() == 0:
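+            # delete=False so the file outlives this block and the tracker can still upload it as an artifact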
+            with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
+                import json
+
+                json.dump(outputs, f)
+                levanter.tracker.current_tracker().log_artifact(
+                    f.name, name=f"lm_eval_output.{step.step}", type="lm_eval_output"
+                )
+
+            log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
+
+    return lm_eval_harness
+
+
 if __name__ == "__main__":
     levanter.config.main(run_eval_harness_main)()
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index 9c598b63c..cf327956b 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -13,6 +13,8 @@
 from haliax.partitioning import named_jit, round_axis_for_partitioning
 
 import levanter
+import levanter.eval
+import levanter.eval_harness
 from levanter import callbacks
 from levanter.checkpoint import EpochCheckpointer, load_checkpoint
 from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback
@@ -253,7 +255,7 @@ def main(config: TrainLmConfig):
         if config.eval_harness is not None:
             eval_harness = config.eval_harness
             trainer.add_hook(
-                callbacks.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping),
+                levanter.eval_harness.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping),
                 every=config.eval_harness_steps,
             )