diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml
index 3a21732a7..291213d75 100644
--- a/config/gpt2_small_fast_pile.yaml
+++ b/config/gpt2_small_fast_pile.yaml
@@ -1,4 +1,4 @@
-data: !include data/pile_source_old.yaml
+data: !include data/pile_mixture.yaml
 model:
   type: gpt2
   hidden_dim: 768
diff --git a/config/gpt2_small_fast_supervised.yaml b/config/gpt2_small_fast_supervised.yaml
index d71e1267e..93675366d 100644
--- a/config/gpt2_small_fast_supervised.yaml
+++ b/config/gpt2_small_fast_supervised.yaml
@@ -15,6 +15,7 @@ data:
 supervised_data:
   validation_urls:
     - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"
+    - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-validation-evaluation.jsonl.gz"
   cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/"
   input_field: "input"
   output_field: "output"
diff --git a/config/llama3_small_fast.yaml b/config/llama3_small_fast.yaml
new file mode 100644
index 000000000..df1df9f96
--- /dev/null
+++ b/config/llama3_small_fast.yaml
@@ -0,0 +1,32 @@
+data:
+  train_urls:
+    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
+  validation_urls:
+    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
+  cache_dir: "gs://levanter-data/tokenized/openwebtext_llama3/"
+  tokenizer: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
+model:
+  type: llama
+  hidden_dim: 768
+  intermediate_dim: 2048
+  num_heads: 12
+  num_kv_heads: 12
+  num_layers: 12
+  seq_len: 1024
+  gradient_checkpointing: true
+trainer:
+  tracker:
+    - type: wandb
+      project: "levanter"
+      tags: [ "openwebtext", "llama", "itest"]
+
+  mp: p=f32,c=bfloat16
+  model_axis_size: 1
+  per_device_parallelism: -1
+
+  train_batch_size: 256
+  num_train_steps: 20000
+optimizer:
+  learning_rate: 1E-3
+  weight_decay: 0.1
+  warmup: 0.01
diff --git a/config/llama_7b_with_dclm.yaml b/config/llama_7b_with_dclm.yaml
index 980e64e41..11a182f09 100644
--- a/config/llama_7b_with_dclm.yaml
+++ b/config/llama_7b_with_dclm.yaml
@@ -17,7 +17,7 @@ trainer:
 
   mp: p=f32,c=bfloat16
   train_batch_size: 2048
-  num_train_steps: 70000  # 280B / 4M
+  num_train_steps: 480000  # 2T / 4M
   steps_per_eval: 1000
   tensor_parallel_axes: ["mlp", "heads"]
   fsdp_axis: "embed"
diff --git a/infra/cluster/job-cluster.yaml b/infra/cluster/job-cluster.yaml
index cf8703d54..cff7d4884 100644
--- a/infra/cluster/job-cluster.yaml
+++ b/infra/cluster/job-cluster.yaml
@@ -14,8 +14,8 @@ cluster_name: levanter-cluster
 # Configure GCP
 provider:
   type: gcp
-  region: us-central2
-  availability_zone: us-central2-b
+  region: us-west4
+  availability_zone: us-west4-a
   project_id: hai-gcp-models
 
 # Maximum Workers (excluding Head Node)
@@ -126,6 +126,45 @@ available_node_types:
       schedulingConfig:
         preemptible: true
 
+  tpu_slice_v5e_16:
+    min_workers: 0
+    max_workers: 1024
+    resources: { "CPU": 120, "TPU": 4 }
+
+    node_config:
+      acceleratorType: v5litepod-16
+      runtimeVersion: tpu-ubuntu2204-base
+
+      # [IMPORTANT] Configure all TPU Workers to be Preemptible!
+      schedulingConfig:
+        preemptible: true
+
+  tpu_slice_v5e_64:
+    min_workers: 0
+    max_workers: 1024
+    resources: { "CPU": 120, "TPU": 4 }
+
+    node_config:
+      acceleratorType: v5litepod-64
+      runtimeVersion: tpu-ubuntu2204-base
+
+      # [IMPORTANT] Configure all TPU Workers to be Preemptible!
+      schedulingConfig:
+        preemptible: true
+
+  tpu_slice_v5e_256:
+    min_workers: 0
+    max_workers: 1024
+    resources: { "CPU": 120, "TPU": 4 }
+
+    node_config:
+      acceleratorType: v5litepod-256
+      runtimeVersion: tpu-ubuntu2204-base
+
+      # [IMPORTANT] Configure all TPU Workers to be Preemptible!
+      schedulingConfig:
+        preemptible: true
+
 docker:
     image: "ghcr.io/stanford-crfm/levanter-cluster:latest"
     container_name: "ray_docker"
@@ -140,7 +179,7 @@ docker:
         - -v "/var/run/docker.sock:/var/run/docker.sock"
 
 initialization_commands:
-  - yes | gcloud auth configure-docker us-central2-docker.pkg.dev
+  - yes | gcloud auth configure-docker us-west4-docker.pkg.dev
   - "export TPU_WORKER_ID=$(curl -H 'Metadata-Flavor: Google' http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number) || true"
   - which docker || (curl -fsSL https://get.docker.com -o get-docker.sh; sudo sh get-docker.sh; sudo usermod -aG docker $USER; sudo systemctl restart docker -f)
   # always run this because ray doesn't run with sudo
diff --git a/infra/launch_on_ray.py b/infra/launch_on_ray.py
index fa5e81f27..90f2c586a 100755
--- a/infra/launch_on_ray.py
+++ b/infra/launch_on_ray.py
@@ -27,7 +27,7 @@ def main():
     cli.add_arg(parser, config, ["--project"], default=cli.gcloud_config()["project"])
     cli.add_arg(parser, config, ["--tpu_type"], required=True)
     # TODO: bring node_count to Ray
-    # cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
+    cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
     cli.add_arg(parser, config, ["--foreground"], default=False, action="store_true")
     cli.add_arg(parser, config, ["--retries"], default=10, type=int)
     cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str)
@@ -122,6 +122,7 @@ def main():
         env=env,
         name="levanter",
         retries=retries,
+        node_count=args.node_count,
     )
 
     address = args.address or os.getenv("RAY_ADDRESS")
diff --git a/pyproject.toml b/pyproject.toml
index 19fb077bf..0831605cb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,7 @@ dependencies = [
     "draccus>=0.8.0",
     "pyarrow>=11.0.0",
     "zstandard>=0.20.0",
-    "datasets>=2.18,<4.0",
+    "datasets>=3.1.0,<4.0",
     "gcsfs>=2024.2,<2024.10",
     "braceexpand>=0.1.7",
     "jmp>=0.0.3",
diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index 333ddf768..cca3156b8 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -262,7 +262,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         dataset = self._load_dataset()
         if isinstance(dataset, datasets.IterableDataset) and shard_name != "data":
             # ex_iterable has a key that gets discarded typically
-            shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards))
+            shard = map(
+                lambda t: t[1],
+                dataset._ex_iterable.shard_data_sources(index=int(shard_name), num_shards=dataset.n_shards),
+            )
         else:
             shard = dataset
 
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index b42bcf5f6..e0bf93466 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -35,7 +35,7 @@
 from levanter.store.cache import CacheOptions, TreeCache
 from levanter.store.jagged_array import JaggedArrayStore
 from levanter.store.tree_store import TreeStore
-from levanter.utils.fsspec_utils import fsspec_expand_glob
+from levanter.utils.fsspec_utils import expand_glob
 from levanter.utils.hf_utils import num_cpus_used_by_tokenizer
 
 
@@ -588,7 +588,7 @@ def urls_for_split(self, split):
         else:
             raise ValueError(f"Unknown split {split}")
 
-        urls = [globbed for url in urls for globbed in fsspec_expand_glob(url)]
+        urls = [globbed for url in urls for globbed in expand_glob(url)]
         return urls
 
 
diff --git a/src/levanter/eval.py b/src/levanter/eval.py
index 555dd1466..99e132dc2 100644
--- a/src/levanter/eval.py
+++ b/src/levanter/eval.py
@@ -63,7 +63,7 @@ def __init__(
         self.datasets = []
         tag_index: dict[str, int] = {}
         for i, (dataset, tags) in enumerate(datasets):
-            if tags is None:
+            if not tags and len(datasets) > 1:
                 warnings.warn("Dataset has no tags. Giving it an index")
                 tags = [f"domain_{i}"]
             for tag in tags:
@@ -204,14 +204,16 @@ def eval_callback(step: StepInfo):
         }
 
         logger.info(f"{prefix} loss: {result.micro_avg_loss:.3f}")
-        for tag, loss in result.tag_macro_losses.items():
-            # don't log leaf tag macro losses because it doesn't mean anything different than micro loss
-            if tag in evaluator.dataset.tag_to_index:
-                continue
-            if not tag:
-                continue
-            log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss
-            logger.info(f"{tag} macro loss: {loss:.3f}")
+        has_tags = len(evaluator.dataset.tag_to_index) > 1  # 1 tag means there's no difference between micro and macro
+        if has_tags:
+            for tag, loss in result.tag_macro_losses.items():
+                # don't log leaf tag macro losses because it doesn't mean anything different than micro loss
+                if tag in evaluator.dataset.tag_to_index:
+                    continue
+                if not tag:
+                    continue
+                log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss
+                logger.info(f"{tag} macro loss: {loss:.3f}")
 
         for tag, loss in result.tag_micro_losses.items():
             if not tag:
@@ -225,11 +227,14 @@ def eval_callback(step: StepInfo):
 
         if tokenizer is not None:
             log_dict[_join_prefix(prefix, "bpb")] = result.micro_bpb
-            log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb
+            if has_tags:
+                log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb
             for tag, bpb in result.tag_micro_bpb.items():
                 log_dict[_join_prefix(prefix, tag) + "/bpb"] = bpb
-            for tag, bpb in result.tag_macro_bpb.items():
-                log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb
+
+            if has_tags:
+                for tag, bpb in result.tag_macro_bpb.items():
+                    log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb
 
         levanter.tracker.log_metrics(log_dict, step=step.step)
 
@@ -304,26 +309,29 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample,
                 this_loss_per_tag = hax.einsum("-> tag", mask, losses, tags)  # [Tag]
 
                 mean = state.token_avg_loss.add(this_loss / this_tokens, this_tokens)
-                # careful: this_tokens_per_tag can be 0 if there are no tokens for that tag
-                safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0)
-                mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag)
+                state = dataclasses.replace(state, token_avg_loss=mean)
 
-                state = dataclasses.replace(state, token_avg_loss=mean, loss_per_tag=mean_per_tag)
+                if len(self.dataset.tag_to_index) > 0:
+                    # careful: this_tokens_per_tag can be 0 if there are no tokens for that tag
+                    safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0)
+                    mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag)
+                    state = dataclasses.replace(state, loss_per_tag=mean_per_tag)
 
                 if self.bytes_per_token is not None:
                     next_tokens = hax.roll(batch.tokens, -1, m.Pos)  # [Batch, Pos], rolled by 1 for next token task
                     bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens)  # [Batch, Pos]
-                    bytes_per_pos = bytes_per_pos * mask  # [Batch, Pos]
-                    bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, tags)  # [Tag]
-                    total_bytes = hax.sum(bytes_per_tag)
+                    bytes_per_tag = hax.einsum("-> tag", mask, bytes_per_pos, tags)  # [Tag]
+                    this_bytes = hax.einsum("->", bytes_per_pos, mask)  # Scalar
 
                     # log loss -> bits is log2(e) * loss
                     bpb_per_tag = this_loss_per_tag / hax.maximum(bytes_per_tag, 1) * jnp.log2(jnp.e)
-                    bpb = this_loss / hax.maximum(total_bytes, 1) * jnp.log2(jnp.e)
+                    bpb = this_loss / hax.maximum(this_bytes, 1) * jnp.log2(jnp.e)
 
                     bpb_mean = state.bpb.add(bpb, this_tokens)
-                    bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag)
-                    state = dataclasses.replace(state, bpb=bpb_mean, bpb_per_tag=bpb_per_tag_mean)
+                    state = dataclasses.replace(state, bpb=bpb_mean)
+                    if len(self.dataset.tag_to_index) > 0:
+                        bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag)
+                        state = dataclasses.replace(state, bpb_per_tag=bpb_per_tag_mean)
 
             return state
 
diff --git a/src/levanter/infra/cli_helpers.py b/src/levanter/infra/cli_helpers.py
index b92b6efb5..58413ef2b 100644
--- a/src/levanter/infra/cli_helpers.py
+++ b/src/levanter/infra/cli_helpers.py
@@ -76,6 +76,11 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante
         "/tmp:/tmp",
     ]
 
+    # optionally add multislice env vars (if set by ray runtime env vars)
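+    # passing `-e NAME` with no value tells docker to forward the variable from the calling environment (if set)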
+    for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]:
+        v = shlex.quote(str(v))
+        docker_command.extend(["-e", v])
+
     for k, v in env.items():
         v = shlex.quote(str(v))
         k = shlex.quote(str(k))
diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py
index 2dc554808..1a9342c54 100644
--- a/src/levanter/infra/ray_tpu.py
+++ b/src/levanter/infra/ray_tpu.py
@@ -3,6 +3,7 @@
 import logging
 import multiprocessing
 import os
+import socket
 import subprocess
 import tempfile
 import time
@@ -10,6 +11,7 @@
 from typing import Callable, Optional, Sequence
 
 import draccus
+import mergedeep
 import ray
 from ray._private.accelerators import TPUAcceleratorManager
 from ray.dashboard.modules.job.sdk import JobSubmissionClient
@@ -104,7 +106,83 @@ def do_run(remote_fn) -> _TpuRunResult:
     return do_run.remote(remote_fn)
 
 
-def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):
+def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> list[ray.ObjectRef]:
+    """
+    Run a remote function on multiple TPU slices.
+
+    Args:
+        remote_fn: A remote function that takes no arguments
+        tpu_type: The type of TPU to run on, e.g. "v4-32"
+        num_slices: The number of slices to run
+
+    Returns:
+        A list of Ray ObjectRefs, one per slice, representing the result of the function
+    """
+
+    @ray.remote(resources={f"TPU-{tpu_type}-head": 1})
+    class MultisliceActor:
+        def __init__(self):
+            self.pod_name = ray.util.accelerators.tpu.get_current_pod_name()
+            self.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
+            self.ip = socket.gethostbyname(socket.gethostname())
+
+        def get_slice_info(self):
+            return self.pod_name, self.num_hosts, self.ip
+
+        def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:
+            port = 8081
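+            # MXLA multislice coordination: every slice is given the coordinator address (slice 0's IP),
+            # the total number of slices, and its own slice id via these env vars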
+            mxla_env = {
+                "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
+                "MEGASCALE_NUM_SLICES": str(num_slices),
+                "MEGASCALE_PORT": f"{port}",
+                "MEGASCALE_SLICE_ID": str(slice_id),
+            }
+
+            remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, self.num_hosts, env_vars=mxla_env)
+
+            info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
+            futures = [remote_fn.remote() for _ in range(self.num_hosts)]
+            try:
+                out = ray.get(futures)
+                logger.info("TPU job finished")
+                return TpuSuccess(info, out)
+            except RayError as e:
+                for f in futures:
+                    try:
+                        ray.cancel(f)
+                    except Exception:
+                        logger.exception("Failed to kill job after primary failure")
+                return _handle_ray_error(info, e)
+            except Exception as e:
+                for f in futures:
+                    try:
+                        ray.cancel(f)
+                    except Exception:
+                        logger.exception("Failed to kill job after primary failure")
+                return TpuFailed(info, e)
+
+    actors = [MultisliceActor.remote() for _ in range(num_slices)]  # type: ignore
+    futures = [actor.get_slice_info.remote() for actor in actors]
+    try:
+        logger.info("Getting slice infos...")
+        # also act as a sync step
+        slice_infos = ray.get(futures)
+        logger.info(f"TPU slice infos {slice_infos}")
+    except RayError as e:
+        logger.exception(e)
+        for actor in actors:
+            try:
+                ray.cancel(actor)
+            except Exception:
+                logger.exception("Failed to kill actor after primary failure")
+        return futures
+
+    coordinator_ip = slice_infos[0][2]
+
+    return [actor.do_run.remote(remote_fn, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)]
+
+
+def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):
     """
     Redecorate a remote function to run on a TPU pod.
 
@@ -120,7 +198,17 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):
 
     tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> my-tpu
     num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators()  # -> 8
-    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": num_tpus_per_host})
+
+    # ray doesn't merge the runtime envs properly, so we have to do it ourselves
+    # we need to do a deep merge
+    sources = [e for e in [remote_fn._runtime_env, runtime_env] if e is not None]
+    runtime_env = mergedeep.merge({}, *sources, strategy=mergedeep.Strategy.ADDITIVE)
+
+    remote_fn = remote_fn.options(
+        runtime_env=runtime_env,
+        resources={tpu_name: 1, "TPU": num_tpus_per_host},
+    )
+
     logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host")
     return remote_fn, tpu_name
 
@@ -193,11 +281,107 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re
         raise RuntimeError("Failed too many times") from problem
 
 
+def run_on_pod_multislice_resumable(
+    remote_fn, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10
+):
+    """
+    Repeatedly run a function on multiple TPU slices until it succeeds or a maximum number of retries is reached.
+
+    Args:
+        remote_fn: A remote function that takes no arguments
+        tpu_type: The type of TPU to run on, e.g. "v4-32"
+        num_slices: The number of slices to run
+        max_retries_preemption: The maximum number of times to retry if the job is preempted
+        max_retries_failure: The maximum number of times to retry if the job fails
+
+    Returns:
+        The result of the function (not an ObjectRef)
+
+    """
+    num_failures = 0
+    num_preemptions = 0
+    attempt = 0
+    problem: Exception | None = None
+
+    while num_failures < max_retries_failure and num_preemptions < max_retries_preemption:
+        logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}")
+        attempt += 1
+        problem = None
+        futures = run_on_pod_multislice(remote_fn, tpu_type, num_slices)
+        try:
+            outs = ray.get(futures)
+        except ray.exceptions.RayTaskError as e:
+            for f in futures:
+                try:
+                    ray.cancel(f)
+                except Exception:
+                    logger.exception("Failed to kill job after primary failure")
+            problem = e
+            if "preempted" in str(e).lower():
+                num_preemptions += 1
+                logger.warning(f"Preempted {num_preemptions} times, {e}")
+            else:
+                num_failures += 1
+                logger.warning(f"Failed {num_failures} times", exc_info=e)
+            continue
+        except Exception as e:
+            for f in futures:
+                try:
+                    ray.cancel(f)
+                except Exception:
+                    logger.exception("Failed to kill job after primary failure")
+            problem = e
+            num_failures += 1
+            if num_failures >= max_retries_failure:
+                logger.exception("Failed too many times", exc_info=e)
+                raise e
+            else:
+                logger.warning(f"Failed {num_failures} times", exc_info=e)
+                continue
+
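+        # classify per-slice results: all successes -> done; preemption or node failure -> retry as preemption;
+        # a run error counts as a genuine failure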
+        if all(isinstance(out, TpuSuccess) for out in outs):
+            results = [out.result for out in outs]
+            logger.info("Success")
+            return results
+        elif any(isinstance(out, TpuPreempted) for out in outs):
+            out = None
+            for o in outs:
+                if isinstance(o, TpuPreempted):
+                    out = o
+            assert out is not None
+            problem = out.error
+            num_preemptions += 1
+            logger.warning(f"Preempted {num_preemptions} times. {problem}", exc_info=problem)
+        elif any(isinstance(out, TpuFailed) for out in outs):
+            num_preemptions += 1
+            logger.warning(f"TPU node failure. Treating as preempted: {num_preemptions} times")
+        elif any(isinstance(out, TpuRunError) for out in outs):
+            out = None
+            for o in outs:
+                if isinstance(o, TpuRunError):
+                    out = o
+            assert out is not None
+            problem = out.error
+            num_failures += 1
+            logger.warning(f"Failed {num_failures} times", exc_info=problem)
+        else:
+            raise RuntimeError(f"Unexpected result: {outs}")
+
+    if num_preemptions >= max_retries_preemption:
+        raise RuntimeError("Preempted too many times") from problem
+    elif num_failures >= max_retries_failure:
+        raise RuntimeError("Failed too many times") from problem
+
+
 def _run_command(*args, **kwargs):
     return subprocess.check_call(args, **kwargs)
 
 
-def run_docker_on_pod(image_id: str, command: Sequence[str], *, tpu_type: str, env: dict, name="levanter", retries=10):
+def run_docker_on_pod(
+    image_id: str, command: Sequence[str], *, tpu_type: str, num_slices: int, env: dict, name="levanter", retries=10
+):
     env = _massage_env(env)
 
     docker_cmd = make_docker_run_command(image_id, command, env=env, foreground=True, name=name)
@@ -210,9 +394,18 @@ def run_docker():
             logger.exception("Failed to run docker command")
             raise e
 
-    run_on_pod_resumable(
-        ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
-    )
+    if num_slices == 1:
+        run_on_pod_resumable(
+            ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
+        )
+    else:
+        run_on_pod_multislice_resumable(
+            ray.remote(run_docker),
+            tpu_type=tpu_type,
+            num_slices=num_slices,
+            max_retries_failure=retries,
+            max_retries_preemption=10000,
+        )
 
 
 def _kill_old_container(name):
@@ -351,6 +544,7 @@ class RunDockerOnPodConfig:
     env: dict = dataclasses.field(default_factory=dict)
     name: str = "levanter"
     retries: int = 10
+    node_count: int = 1
 
 
 def submit_tpu_job_on_ray(config: RunDockerOnPodConfig, ray_address: str, run_id: Optional[str] = None):
@@ -419,6 +613,8 @@ def main(args: RunDockerOnPodConfig):
         tpu_type=args.tpu_type,
         env=args.env,
         name=args.name,
+        retries=args.retries,
+        num_slices=args.node_count,
     )
 
 
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index ee7e353c5..b1b5d4aaa 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -185,13 +185,13 @@ def main(config: TrainLmConfig):
 
         levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)})
 
+        max_eval_examples_per_ds = config.trainer.max_eval_batches
+        if max_eval_examples_per_ds is not None:
+            max_eval_examples_per_ds *= config.trainer.eval_batch_size
+
         if len(tagged_eval_datasets) == 0:
             logger.warning("No evaluation datasets provided.")
         else:
-            max_eval_examples_per_ds = config.trainer.max_eval_batches
-            if max_eval_examples_per_ds is not None:
-                max_eval_examples_per_ds *= config.trainer.eval_batch_size
-
             causal_datasets = [
                 (CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags)
                 for ds, tags in tagged_eval_datasets
diff --git a/src/levanter/models/backpack.py b/src/levanter/models/backpack.py
index 2a955395f..4de8accc7 100644
--- a/src/levanter/models/backpack.py
+++ b/src/levanter/models/backpack.py
@@ -401,7 +401,7 @@ def init(Vocab: Axis, config: BackpackConfig, *, key):
         )
 
     @named_call
-    def __call__(
+    def activations(
         self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
     ) -> NamedArray:
         k_embed, k_transformer, k_senses, k_sa = haliax.jax_utils.maybe_rng_split(key, 4)
@@ -428,9 +428,10 @@ def __call__(
         scale = self.config.Senses.size
         hidden_states = hidden_states / scale
 
-        lm_logits = self.embeddings.unembed(hidden_states)
+        return hidden_states
 
-        return lm_logits
+    def get_lm_head(self) -> hax.NamedArray:
+        return self.embeddings.token_embeddings
 
     def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None):
         new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
diff --git a/src/levanter/models/gemma.py b/src/levanter/models/gemma.py
index af5cc44be..c38acf5ef 100644
--- a/src/levanter/models/gemma.py
+++ b/src/levanter/models/gemma.py
@@ -339,6 +339,9 @@ def vocab_size(self) -> int:
     def Vocab(self) -> Axis:
         return self.embeddings.Vocab
 
+    def get_lm_head(self) -> hax.NamedArray:
+        return self.embeddings.token_embeddings.weight
+
     @classmethod
     def init(cls, Vocab: Axis, config: GemmaConfig, *, key) -> "GemmaLMHeadModel":
         k_t, k_emb = jrandom.split(key, 2)
@@ -346,7 +349,7 @@ def init(cls, Vocab: Axis, config: GemmaConfig, *, key) -> "GemmaLMHeadModel":
         embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb)
         return GemmaLMHeadModel(transformer, embeddings)
 
-    def __call__(
+    def activations(
         self,
         input_ids: NamedArray,
         attn_mask: Optional[Union[NamedArray, AttentionMask]] = None,
@@ -363,8 +366,7 @@ def __call__(
         """
         x = self.embeddings.embed(input_ids)
         x = self.transformer(x, attn_mask=attn_mask, key=key)
-        lm_logits = self.embeddings.unembed(x)
-        return lm_logits
+        return x
 
     def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[GemmaConfig]":
         new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py
index a921074e9..28e878193 100644
--- a/src/levanter/models/gpt2.py
+++ b/src/levanter/models/gpt2.py
@@ -391,15 +391,17 @@ def init(cls, Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2LMHeadModel":
 
         return Gpt2LMHeadModel(transformer, embeddings)
 
-    def __call__(
+    def activations(
         self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
     ) -> NamedArray:
         k_embed, k_transformer = haliax.jax_utils.maybe_rng_split(key, 2)
         x = self.embeddings.embed(input_ids, key=k_embed)
         x = self.transformer(x, attn_mask, key=k_transformer)
-        lm_logits = self.embeddings.unembed(x)
 
-        return lm_logits
+        return x
+
+    def get_lm_head(self) -> hax.NamedArray:
+        return self.embeddings.token_embeddings.weight
 
     def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None) -> "Gpt2LMHeadModel":
         new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
index 1e09ffbc5..85861da6a 100644
--- a/src/levanter/models/llama.py
+++ b/src/levanter/models/llama.py
@@ -557,6 +557,31 @@ def __call__(
             lm_logits = self.embeddings.unembed(x)
         return lm_logits
 
+    def activations(
+        self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
+    ) -> NamedArray:
+        """
+        Compute the activations for the next token in a sequence.
+        Args:
+            input_ids: token IDs with shape {Pos}
+            attn_mask: attention mask with shape {Pos, KeyPos}
+            key: PRNGKey for random number generation
+
+        Returns:
+            NamedArray: activations with shape {Pos, Embed}
+
+        """
+        x = self.embeddings.embed(input_ids)
+        x = self.transformer(x, attn_mask=attn_mask, key=key)
+
+        return x
+
+    def get_lm_head(self) -> hax.NamedArray:
+        if self.lm_head is None:
+            return self.embeddings.token_embeddings.weight
+        else:
+            return self.lm_head.weight
+
     def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
         new_Vocab = self.Vocab.resize(new_size)
         k1, k2 = maybe_rng_split(key, 2)
diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py
index 468f6a4a4..911e74b09 100644
--- a/src/levanter/models/lm_model.py
+++ b/src/levanter/models/lm_model.py
@@ -64,6 +64,18 @@ def KeyPos(self) -> Axis:
     def Pos(self) -> Axis:
         pass
 
+    @property
+    @abc.abstractmethod
+    def Embed(self) -> Axis:
+        pass
+
+    cross_entropy_block_size: Optional[int] = 64000
+    """
+    The block size for computing cross-entropy loss. This is the number of vocabulary entries that are processed
+    together in a single block. It can be adjusted to fit within memory constraints. It's deliberately set to a
+    large value because it is usually faster to compute the loss in larger blocks.
+    """
+
     def flops_per_token(self, vocab_size: int) -> Optional[float]:
         return None
 
@@ -94,17 +106,58 @@ def Pos(self) -> Axis:
     def KeyPos(self) -> Axis:
         return self.config.KeyPos
 
+    @property
+    def Embed(self) -> Axis:
+        return self.config.Embed
+
     @classmethod
     @abc.abstractmethod
     def init(cls, Vocab: Axis, config: LmConfigT, *, key: PRNGKey) -> "LmHeadModel[LmConfigT]":
         pass
 
-    @abc.abstractmethod
     def __call__(
         self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
     ) -> NamedArray:
+        """
+        Compute the logits for the next token in a sequence.
+        Args:
+            input_ids: token IDs with shape [..., Pos]
+            attn_mask: attention mask with shape [..., Pos, KeyPos]
+            key: PRNGKey for random number generation
+
+        Returns:
+            NamedArray: logits with shape [..., Pos, Vocab]
+
+        """
+        x = self.activations(input_ids, attn_mask, key=key)
+        lm_logits = hax.dot(x, self.get_lm_head(), axis=self.Embed)
+
+        return lm_logits
+
+    @abc.abstractmethod
+    def activations(
+        self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
+    ) -> NamedArray:
+        """
+        Compute the activations for the next token in a sequence.
+        Args:
+            input_ids: token IDs with shape {Pos}
+            attn_mask: attention mask with shape {Pos, KeyPos}
+            key: PRNGKey for random number generation
+
+        Returns:
+            NamedArray: activations with shape {Pos, Embed}
+
+        """
         pass
 
+    @abc.abstractmethod
+    def get_lm_head(self) -> hax.NamedArray:
+        """
+        The language modeling head of the model. Should have shape {Embed, Vocab}.
+        """
+        raise NotImplementedError("get_lm_head not implemented")
+
     @abc.abstractmethod
     def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "LmHeadModel[LmConfigT]":
         """
@@ -133,19 +186,21 @@ def compute_next_token_loss(
     across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
     reduced, and the result is a named array with axes (*batch axes, sequence_length).
     """
-    logits = model(example.tokens, example.attn_mask, key=key)
-    if loss_dtype is not None:
-        logits = logits.astype(loss_dtype)
+    activations = model.activations(example.tokens, example.attn_mask, key=key)
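+    # defer the unembedding to next_token_loss so it can (optionally) fuse it block-wise with the loss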
 
     loss = next_token_loss(
         model.Pos,
+        model.Embed,
         model.Vocab,
-        logits,
+        activations,
+        model.get_lm_head(),
         example.tokens,
         loss_mask=example.loss_mask,
         reduction=reduction,
         reduction_axis=reduction_axis,
         logsumexp_weight=logsumexp_weight,
+        dtype=loss_dtype,
+        block_size=model.config.cross_entropy_block_size,
     )
 
     return loss
diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py
index 1ef7e81f9..154fc66ac 100644
--- a/src/levanter/models/loss.py
+++ b/src/levanter/models/loss.py
@@ -1,5 +1,8 @@
+import functools
 from typing import Optional
 
+import equinox
+import jax
 import jax.numpy as jnp
 
 import haliax as hax
@@ -9,34 +12,77 @@
 
 def next_token_loss(
     Pos: hax.AxisSelector,
+    Embed: hax.AxisSelector,
     Vocab: hax.AxisSelector,
-    pred_ids: NamedArray,
+    pred_embeddings: NamedArray,
+    pred_lm_head: NamedArray,
     true_ids: NamedArray,
     loss_mask: Optional[NamedArray] = None,
     reduction: Optional[hax.ReductionFunction] = hax.mean,
     reduction_axis: Optional[hax.AxisSelection] = None,
     logsumexp_weight: Optional[float] = None,
-):
-    Pos, Vocab = pred_ids.resolve_axis((Pos, Vocab))
-    # need to roll the target tokens back by one so that each token is predicting the next token
+    block_size: Optional[int] = None,
+    dtype: Optional[jnp.dtype] = jnp.float32,
+) -> NamedArray:
+    """
+    Compute the next token loss with optional block-wise processing.
+
+    Args:
+        Pos (hax.AxisSelector): Position axis selector.
+        Embed (hax.AxisSelector): Embedding (contraction) axis selector.
+        Vocab (hax.AxisSelector): Vocabulary axis selector.
+        pred_embeddings (NamedArray): Predicted embeddings.
+        pred_lm_head (NamedArray): Language model head weights.
+        true_ids (NamedArray): True token IDs.
+        loss_mask (Optional[NamedArray]): Mask to apply to the loss.
+        reduction (Optional[hax.ReductionFunction]): Reduction function.
+        reduction_axis (Optional[hax.AxisSelection]): Axis to apply reduction.
+        logsumexp_weight (Optional[float]): Weight for logsumexp penalty.
+        block_size (Optional[int]): Size of each vocabulary block for processing; None computes the full logits.
+        dtype (Optional[jnp.dtype]): Data type used for the logits and loss computation.
+
+    Returns:
+        NamedArray: Computed loss.
+    """
+    # Resolve axes
+    Pos = pred_embeddings.resolve_axis(Pos)
+    Vocab = pred_lm_head.resolve_axis(Vocab)
+
+    # Shift target tokens to predict the next token
     target_y = hax.roll(true_ids, -1, Pos)
-    target_y = hax.nn.one_hot(target_y, Vocab, dtype=pred_ids.dtype)  # type: ignore
 
-    # one everywhere except the last token
+    # Create a mask that excludes the last token
     not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32)  # type: ignore
     if loss_mask is not None:
         loss_mask = loss_mask * not_last_loss_mask
     else:
         loss_mask = not_last_loss_mask
 
-    return cross_entropy_and_logsumexp_penalty(
-        pred_ids,
-        Vocab,
-        target_y,
+    if block_size is None:
+        # Full softmax computation
+        logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed, preferred_element_type=dtype)
+        target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype)
+        return cross_entropy_and_logsumexp_penalty(
+            logits,
+            Vocab,
+            target_y_full,
+            reduction=reduction,
+            reduction_axis=reduction_axis,
+            where=loss_mask,
+            logsumexp_weight=logsumexp_weight,
+        )
+
+    # Compute the loss with optional block-wise processing
+    return fused_cross_entropy_loss_and_logsumexp_penalty(
+        pred_embeddings,
+        pred_lm_head,
+        Contract=Embed,
+        Label=Vocab,
+        target_y=target_y,
         reduction=reduction,
         reduction_axis=reduction_axis,
         where=loss_mask,
         logsumexp_weight=logsumexp_weight,
+        block_size=block_size,
+        dtype=dtype,
     )
 
 
@@ -58,3 +104,345 @@ def cross_entropy_and_logsumexp_penalty(
         loss = loss + logsumexp_weight * (log_normalizers**2)
 
     return hax.nn.loss.maybe_reduce_loss(loss, reduction, reduction_axis, where)
+
+
+def fused_cross_entropy_loss_and_logsumexp_penalty(
+    pred_embeddings: NamedArray,
+    pred_lm_head: NamedArray,
+    Contract: hax.AxisSelector,
+    Label: hax.AxisSelector,
+    target_y: NamedArray,
+    *,
+    reduction: Optional[hax.ReductionFunction] = hax.mean,
+    reduction_axis: Optional[hax.AxisSelection] = None,
+    where: Optional[NamedArray] = None,
+    logsumexp_weight: float | None = 0.0,
+    block_size: int,
+    dtype: Optional[jnp.dtype] = jnp.float32,
+) -> NamedArray:
+    """
+    Compute the cross-entropy loss and logsumexp penalty using embeddings and lm_head,
+    with optional block-wise processing.
+
+    Args:
+        pred_embeddings (NamedArray): Predicted embeddings.
+        pred_lm_head (NamedArray): Language model head weights.
+        Contract (hax.AxisSelector): Axis to contract over.
+        Label (hax.AxisSelector): Label (Vocab) axis.
+        target_y (NamedArray): One-hot encoded target tokens.
+        reduction (Optional[hax.ReductionFunction]): Reduction function.
+        reduction_axis (Optional[hax.AxisSelection]): Axis to apply reduction.
+        where (Optional[NamedArray]): Mask to apply to the loss.
+        logsumexp_weight (float): Weight for logsumexp penalty.
+        block_size (int): Size of each block for processing.
+        dtype (Optional[jnp.dtype]): Data type for the loss.
+
+    Returns:
+        NamedArray: Computed loss.
+    """
+
+    # Block-wise softmax computation
+    loss, log_normalizers = _blockwise_cross_entropy_loss(
+        (pred_embeddings, pred_lm_head), Contract, Label, target_y, block_size, dtype=dtype
+    )
+
+    if logsumexp_weight is not None and (not isinstance(logsumexp_weight, (int, float)) or logsumexp_weight != 0.0):
+        loss = loss + logsumexp_weight * (log_normalizers**2)
+
+    return hax.nn.loss.maybe_reduce_loss(loss, reduction, reduction_axis, where)
+
+
+@equinox.filter_custom_vjp
+def _blockwise_cross_entropy_loss(
+    # pred_embeddings: NamedArray,
+    # pred_lm_head: NamedArray,
+    pred: tuple[NamedArray, NamedArray],
+    Contract: hax.Axis,
+    Label: hax.Axis,
+    labels_y: NamedArray,
+    block_size: int,
+    dtype: Optional[jnp.dtype],
+) -> tuple[NamedArray, NamedArray]:
+    """
+    Compute cross-entropy loss and log normalizers in a block-wise manner without materializing the full logits.
+
+    Args:
+        pred (tuple[NamedArray, NamedArray]): Tuple of (pred_embeddings, pred_lm_head).
+        Contract (hax.Axis): Axis to contract over.
+        Label (hax.AxisSelector): Label (Vocab) axis.
+        labels_y (NamedArray): label tensor.
+        block_size (int): Size of each block for processing.
+        dtype (Optional[jnp.dtype]): Data type for the loss.
+
+    Notes:
+        labels_y being anything other than the raw label tensor would remove the memory benefit.
+
+        TODO: or is XLA smart enough to optimize it out?
+
+    Returns:
+        tuple[NamedArray, NamedArray]: tuple of loss and log_normalizers.
+    """
+
+    return _block_cross_entropy_forward(None, pred, Contract, Label, labels_y, block_size, dtype)[0]
+
+
+def _block_cross_entropy_forward(
+    ignore,
+    pred: tuple[NamedArray, NamedArray],
+    Contract: hax.Axis,
+    Label: hax.Axis,
+    labels_y: NamedArray,
+    block_size: int,
+    dtype: Optional[jnp.dtype],
+) -> tuple[tuple[NamedArray, NamedArray], tuple[NamedArray]]:
+    """
+    Forward pass for block-wise cross-entropy loss.
+
+    This function computes the cross-entropy loss and log-sum-exp (`log_z`) in a block-wise manner
+    to maintain memory efficiency by processing subsets of the vocabulary at a time.
+
+    Args:
+        ignore: Placeholder argument (unused).
+        pred (Tuple[NamedArray, NamedArray]): Tuple containing predicted embeddings and language model head weights.
+        Contract (hax.Axis): Axis to contract over (e.g., embedding axis).
+        Label (hax.Axis): Label axis (e.g., vocabulary axis).
+        labels_y (NamedArray): True target labels [Batch, Seq].
+        block_size (int): Number of vocabulary tokens per block.
+        dtype (Optional[jnp.dtype]): Data type for the computations.
+
+    Returns:
+        Tuple:
+            - Tuple[NamedArray, NamedArray]: Computed loss and logsumexp.
+            - Tuple[NamedArray]: Residuals needed for the backward pass.
+    """
+    vocab_size = Label.size
+
+    pred_embeddings, pred_lm_head = pred
+
+    #
+    # if num_blocks == 1:
+    #     # No need for block-wise processing
+    #     logits = hax.dot(pred_embeddings, pred_lm_head, axis=Contract)
+    #     labels_y = hax.nn.one_hot(labels_y, Label, dtype=pred_embeddings.dtype)
+    #     return cross_entropy_loss_and_log_normalizers(logits, Label, labels_y)
+    #
+    # ensure block size divides vocab size
+    if vocab_size % block_size != 0:
+        has_stragglers = True
+    else:
+        has_stragglers = False
+
+    num_blocks = vocab_size // block_size
+
+    # Initialize accumulators: loss, logsumexp, max_logits
+    initial_O = hax.zeros(labels_y.axes)
+    initial_logsumexp = hax.full(labels_y.axes, -jnp.inf)
+    initial_max = hax.full(labels_y.axes, -jnp.inf)
+    # We don't need this b/c we're using one-hot targets
+    # initial_sumV = hax.full(labels_y.axes, 0.0)
+
+    def process_block(block_idx, acc, current_block_size):
+        """
+        Process a single block of the Vocab dimension.
+
+        Args:
+            block_idx (int): Index of the current block.
+            acc (tuple[NamedArray, NamedArray, jnp.ndarray]): Accumulators for loss, logsumexp, and max logits.
+            current_block_size (int): Size of the current block (used for stragglers).
+
+        Returns:
+            tuple[NamedArray, NamedArray, jnp.ndarray]: Updated accumulators
+        """
+        loss, logsumexp_prev, max_logit_prev = acc
+
+        start = block_idx * block_size
+        Block = Label.resize(current_block_size)
+
+        # Materialize the logits for the current block
+        lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)]  # [Contract, Block]
+        logits_b = hax.dot(
+            pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype
+        )  # [Batch, Seq, Block]
+
+        # Update max and logsumexp
+        max_logit = hax.maximum(max_logit_prev, hax.max(logits_b, axis=Block))  # [Batch, Seq]
+        # reweight the previous logsumexp by the new max, fold in the new logits' contribution
+        logsumexp = max_logit + hax.log(
+            hax.exp(logsumexp_prev - max_logit) + hax.sum(hax.exp(logits_b - max_logit), axis=Block)
+        )  # [Batch, Seq]
+
+        # Materialize the target for the current block (one-hot)
+        target_y_b = _block_one_hot(Block, start, labels_y, logits_b.dtype)  # [Batch, Seq, Block]
+
+        # Update sumV. This is actually unnecessary if we're using one-hot targets
+        # sV = sV_prev + hax.sum(target_y_b, axis=Label.name)
+
+        loss += hax.dot(logits_b, target_y_b, axis=Block, preferred_element_type=dtype)  # [Batch, Seq]
+
+        return loss, logsumexp, max_logit  # , sV
+
+    if num_blocks == 0:
+        o = initial_O
+        log_z = initial_logsumexp
+        max_logits = initial_max
+    elif num_blocks == 1:
+        o, log_z, max_logits = process_block(0, (initial_O, initial_logsumexp, initial_max), vocab_size)
+    else:
+        (o, log_z, max_logits) = jax.lax.fori_loop(
+            lower=0,
+            upper=num_blocks,
+            body_fun=functools.partial(process_block, current_block_size=block_size),
+            init_val=(initial_O, initial_logsumexp, initial_max),  # , initial_sumV
+        )
+
+    if has_stragglers:
+        # Handle the stragglers
+        remainder_size = vocab_size - num_blocks * block_size
+        o, log_z, _ = process_block(num_blocks, (o, log_z, max_logits), remainder_size)
+
+    # unnecessary if we're using one-hot targets
+    # logz_outer = hax.einsum("->...", log_z, sum_v)
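+    # cross entropy per position = logsumexp(logits) - logit[target]; `o` currently holds the target logit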
+    o = log_z - o
+
+    return (o, log_z), (log_z,)
+
+
+def _block_cross_entropy_backward(
+    residuals: tuple[NamedArray,],
+    grad_in: tuple[NamedArray, NamedArray],
+    ignore,
+    pred: tuple[NamedArray, NamedArray],
+    Contract: hax.Axis,
+    Label: hax.Axis,
+    labels_y: NamedArray,
+    block_size: int,
+    dtype: Optional[jnp.dtype],
+) -> tuple[NamedArray, NamedArray]:
+    """
+    Compute the gradients of the block-wise cross-entropy loss.
+
+    Args:
+        residuals (tuple[NamedArray,]): Residuals from the forward pass (the log normalizers).
+        grad_in (tuple[NamedArray, NamedArray]): Incoming gradients.
+        pred (tuple[NamedArray, NamedArray]): Predictions.
+        Contract (hax.Axis): Axis to contract over.
+        Label (hax.Axis): Label axis.
+        labels_y (NamedArray): Target labels.
+        block_size (int): Size of each block.
+        dtype (Optional[jnp.dtype]): Data type for the loss.
+
+    Returns:
+        tuple[NamedArray, NamedArray]: Gradients.
+    """
+
+    (log_z,) = residuals
+    grad_loss, grad_log_z = grad_in
+
+    vocab_size = Label.size
+
+    pred_embeddings, pred_lm_head = pred
+
+    if vocab_size % block_size != 0:
+        has_stragglers = True
+    else:
+        has_stragglers = False
+
+    num_blocks = vocab_size // block_size
+
+    grad_embeddings = hax.zeros(pred_embeddings.axes, dtype=pred_embeddings.dtype)
+    grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_lm_head.dtype)
+
+    def process_block(block_idx, acc, current_block_size):
+        """
+        Process a single block of the Vocab dimension.
+
+        Args:
+            block_idx (int): Index of the current block.
+            acc (tuple[NamedArray, NamedArray]): Accumulators for gradients.
+            current_block_size (int): Size of the current block (used for stragglers).
+
+        Returns:
+            tuple[NamedArray, NamedArray]: Updated accumulators.
+        """
+        grad_embeddings_prev, grad_lm_head_prev = acc
+
+        start = block_idx * block_size
+        Block = Label.resize(current_block_size)
+
+        # Materialize the logits for the current block
+        lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)]  # [Contract, Block]
+        logits_b = hax.dot(
+            pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype
+        )  # [Batch, Seq, Block]
+
+        # Materialize the target for the current block (one-hot)
+        target_y_block = _block_one_hot(Block, start, labels_y, logits_b.dtype)  # [Batch, Seq, Block]
+
+        # materialize the softmax for the current block
+        p_b = hax.exp(logits_b - log_z)  # [Batch, Seq, Block]
+
+        delta_b = p_b - target_y_block
+
+        # dLoss/dlogits = g_loss * delta_b + g_log_z * probs_b
+        #              = g_loss * (probs_b - Y) + g_log_z * probs_b
+        #              = (g_loss + g_log_z) * probs_b - g_loss * Y
+
+        # Compute gradients. We get None if the gradient is not provided.
+        if grad_loss.array is not None:
+            dLoss = grad_loss * delta_b  # [Batch, Seq, Block]
+        else:
+            dLoss = 0.0
+
+        # Add the gradient of the logsumexp term (should be None if not provided)
+        if grad_log_z.array is not None:
+            dLoss += grad_log_z * p_b  # [Batch, Seq, Block]
+
+        # Compute gradients for the current block
+        # embeddings has shape [Batch, Seq, Embed], so we need to eliminate Block
+        g_embeddings_b = hax.dot(
+            dLoss, lm_head_b, axis=Block, preferred_element_type=grad_embeddings.dtype
+        )  # [Batch, Seq, Embed]
+
+        # lm_head has shape [Block, Embed], so we need to eliminate Batch, Seq, etc.
+        eliminated_axes_W = hax.axis.without_axes(pred_embeddings.axes, lm_head_b.axes)
+        g_lm_head_b = hax.dot(
+            dLoss, pred_embeddings, axis=eliminated_axes_W, preferred_element_type=grad_lm_head_prev.dtype
+        )  # [Block, Embed]
+
+        g_lm_head = grad_lm_head_prev.at[Label, hax.dslice(start, Block)].set(g_lm_head_b)
+        g_embeddings = grad_embeddings_prev + g_embeddings_b
+
+        return g_embeddings, g_lm_head
+
+    if num_blocks == 0:
+        pass
+    elif num_blocks == 1:
+        grad_embeddings, grad_lm_head = process_block(0, (grad_embeddings, grad_lm_head), vocab_size)
+    else:
+        grad_embeddings, grad_lm_head = jax.lax.fori_loop(
+            lower=0,
+            upper=num_blocks,
+            body_fun=functools.partial(process_block, current_block_size=block_size),
+            init_val=(grad_embeddings, grad_lm_head),
+        )
+
+    if has_stragglers:
+        # Handle the stragglers
+        remainder_size = vocab_size - num_blocks * block_size
+        grad_embeddings, grad_lm_head = process_block(num_blocks, (grad_embeddings, grad_lm_head), remainder_size)
+
+    return grad_embeddings.astype(pred_embeddings.dtype), grad_lm_head.astype(pred_lm_head.dtype)
+
+
+_blockwise_cross_entropy_loss.def_fwd(_block_cross_entropy_forward)
+_blockwise_cross_entropy_loss.def_bwd(_block_cross_entropy_backward)
+
+
+def _block_one_hot(LBlock, block_start, labels, dtype):
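+    # one-hot encode `labels` restricted to the vocab slice [block_start, block_start + LBlock.size)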
+    end = block_start + LBlock.size
+    target_is_in_this_block = hax.logical_and(labels >= block_start, labels < end)
+    target_y_block = hax.nn.one_hot(labels - block_start, LBlock, dtype=dtype)
+    # zero out the targets whose labels are not in this block
+    target_y_block *= target_is_in_this_block
+    return target_y_block
diff --git a/src/levanter/models/mistral.py b/src/levanter/models/mistral.py
index b48bfbe91..764e18aea 100644
--- a/src/levanter/models/mistral.py
+++ b/src/levanter/models/mistral.py
@@ -175,7 +175,11 @@ def init(cls, Vocab: Axis, config: MistralConfig, *, key) -> "MistralLMHeadModel
         lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)
         return MistralLMHeadModel(transformer, embeddings, lm_head)
 
-    def __call__(
+    def get_lm_head(self) -> hax.NamedArray:
+        assert self.lm_head.bias is None
+        return self.lm_head.weight
+
+    def activations(
         self,
         input_ids: NamedArray,
         attn_mask: Optional[Union[NamedArray, AttentionMask]] = None,
@@ -193,8 +197,7 @@ def __call__(
         k_t, k_head = maybe_rng_split(key, 2)
         x = self.embeddings.embed(input_ids)
         x = self.transformer(x, attn_mask=attn_mask, key=k_t)
-        lm_logits = self.lm_head(x, key=k_head)
-        return lm_logits
+        return x
 
     def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[MistralConfig]":
         new_Vocab = self.Vocab.resize(new_size)
diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py
index 00044a4ed..0809d9d23 100644
--- a/src/levanter/models/mpt.py
+++ b/src/levanter/models/mpt.py
@@ -447,14 +447,15 @@ def init(cls, Vocab: Axis, config: MptConfig, *, key):
         return MptLmHeadModel(wte, transformer, config)
 
     @named_call
-    def __call__(
+    def activations(
         self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray], *, key=None
     ) -> NamedArray:
         hidden_states = self.wte.embed(input_ids)
         hidden_states = self.transformer(hidden_states, attention_mask=attn_mask, key=key)
-        output_logits = self.wte.unembed(hidden_states)
+        return hidden_states
 
-        return output_logits
+    def get_lm_head(self) -> hax.NamedArray:
+        return self.wte.weight
 
     def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "MptLmHeadModel":
         if new_size == self.vocab_size:
diff --git a/src/levanter/models/rotary.py b/src/levanter/models/rotary.py
index 07657e5ff..55bbf3fcb 100644
--- a/src/levanter/models/rotary.py
+++ b/src/levanter/models/rotary.py
@@ -157,6 +157,7 @@ def to_hf_config(self) -> tuple[float, dict]:
             "low_freq_factor": self.low_freq_factor,
             "high_freq_factor": self.high_freq_factor,
             "original_max_position_embeddings": self.original_max_position_embeddings,
+            "rope_type": "llama3",
         }
 
 
diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py
index 45265c994..558bbfceb 100644
--- a/src/levanter/store/cache.py
+++ b/src/levanter/store/cache.py
@@ -3,6 +3,7 @@
 import copy
 import dataclasses
 import logging as pylogging
+import operator
 import os
 import pprint
 import random
@@ -12,26 +13,30 @@
 from concurrent.futures import Future as threading_Future
 from contextlib import AbstractContextManager
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
 
 import deepdiff
 import fsspec.core
-import humanfriendly
 import jax
+import numpy as np
 import pyarrow as pa
 import ray
+import tensorstore as ts
 from dataclasses_json import dataclass_json
 from fsspec import AbstractFileSystem
 from jaxtyping import PyTree
 from ray.actor import ActorHandle
+from ray.runtime_env import RuntimeEnv
+from tqdm_loggable.auto import tqdm
 
+from levanter.data import batched
 from levanter.data.dataset import AsyncDataset
-from levanter.store._prefetch_actor import QueueEmpty, RayPrefetchQueue
-from levanter.utils.py_utils import Stopwatch
 
-from ..data._preprocessor import BatchProcessor, BatchProcessorPool, BatchResult, dict_from_record_batch
+from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch
 from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor
 from ..data.sharded_datasource import ShardedDataSource
+from ..utils.fsspec_utils import exists as fsspec_exists
+from ..utils.fsspec_utils import remove as fsspec_remove
 from ..utils.ray_utils import (
     ExceptionInfo,
     RefBox,
@@ -40,8 +45,7 @@
     log_failures_to,
     ser_exc_info,
 )
-from ..utils.thread_utils import ExceptionTrackingThread
-from .jagged_array import PreparedBatch
+from .jagged_array import JaggedArrayStore, PreparedBatch
 from .tree_store import TreeStore
 
 
@@ -69,23 +73,20 @@ class CacheOptions:
     """
 
     num_shard_groups: Optional[int] = 128
-    """Number of groups to divide the shards into. This is used to parallelize the cache building process without
-    overloading Ray. If None, all shards will be in their own group."""
-    shard_order_randomization_key: Optional[int] = 0
-    """A key used to randomize the order of the shards before building and grouping."""
-    batch_size: int = 128
-    """The batch size to use when processing the data. This is used to control the memory usage of the cache building
-    process. Lower values will use less memory but take somewhat longer to build the cache."""
 
     # the below options don't actually impact the cache's result, but do impact construction
     target_size_per_flush: int | str = "512MB"
     """The number of bytes to buffer before flushing to disk. This is used to control the memory usage of the cache
     building process. Lower values will use less memory but could take somewhat longer to build the cache."""
-    prefetch_per_group: int = 4
-    """The number of batches to prefetch per group. This is used to keep the processors busy and to reduce the time"""
+
+    batch_size: int = 128
 
     @property
     def target_bytes_per_flush(self):
+        if isinstance(self.target_size_per_flush, int):
+            return self.target_size_per_flush
+        import humanfriendly
+
         return humanfriendly.parse_size(self.target_size_per_flush)
 
     @staticmethod
@@ -99,14 +100,14 @@ def no_fanciness(batch_size: Optional[int] = None):
         """
         if batch_size is None:
             batch_size = 128
-        return CacheOptions(num_shard_groups=None, shard_order_randomization_key=None, batch_size=batch_size)
+        return CacheOptions(num_shard_groups=None, batch_size=batch_size)
 
     @staticmethod
     def one_group():
         """
         For testing, disables all the fancy features of the cache. This makes it easier to predict the behavior
         """
-        return CacheOptions(num_shard_groups=1, shard_order_randomization_key=None, batch_size=128)
+        return CacheOptions(num_shard_groups=1, batch_size=128)
 
 
 def build_or_load_cache(
@@ -116,7 +117,6 @@ def build_or_load_cache(
     await_finished: bool = True,
     monitors: Optional[Sequence["MetricsMonitor"]] = None,
     options: CacheOptions = CacheOptions.default(),
-    force_flush: bool = False,
     split: str = "test",
 ) -> "TreeCache[U]":
     """
@@ -144,8 +144,6 @@ def build_or_load_cache(
 
         options: Configuration for the cache. This is used to configure a few parts of the cache creation process
 
-        force_flush: for testing, forces the cache to flush after every batch. This is useful for testing.
-
     Returns:
        (TreeCache) A TreeCache object that can be used to read the cache.
 
@@ -156,7 +154,6 @@ def build_or_load_cache(
         shard_source=input_shards,
         processor=processor,
         options=options,
-        force_flush=force_flush,
         split=split,
     )
 
@@ -320,12 +317,11 @@ def build_or_load(
         shard_source: ShardedDataSource[T],
         processor: BatchProcessor[T, U],
         options: Optional["CacheOptions"] = None,
-        force_flush: bool = False,
         split: str = "test",
     ) -> "TreeCache[U]":
         if options is None:
             options = CacheOptions.default()
-        metadata = CacheMetadata(options=options, preprocessor_metadata=processor.metadata)
+        metadata = CacheMetadata(preprocessor_metadata=processor.metadata)
         try:
             return TreeCache.load(cache_dir, processor.output_exemplar, metadata)
         except FileNotFoundError:
@@ -334,8 +330,6 @@ def build_or_load(
                 shard_source=shard_source,
                 processor=processor,
                 options=options,
-                force_flush=force_flush,
-                split=split,
             )
             return TreeCache(cache_dir=cache_dir, exemplar=processor.output_exemplar, ledger=None, _broker=broker)
 
@@ -489,13 +483,11 @@ class CacheLedger:
     is_finished: bool = False
     finished_shards: List[str] = dataclasses.field(default_factory=list)
     field_counts: Dict[str, int] = dataclasses.field(default_factory=dict)
-    metadata: "CacheMetadata" = dataclasses.field(default_factory=lambda: CacheMetadata(CacheOptions(), {}))
+    metadata: "CacheMetadata" = dataclasses.field(default_factory=lambda: CacheMetadata({}))
 
     @staticmethod
-    def load_or_initialize(
-        cache_dir: str, source: ShardedDataSource, processor: BatchProcessor, config: "CacheOptions"
-    ):
-        metadata = CacheMetadata(options=config, preprocessor_metadata=processor.metadata)
+    def load_or_initialize(cache_dir: str, source: ShardedDataSource, processor: BatchProcessor):
+        metadata = CacheMetadata(preprocessor_metadata=processor.metadata)
         try:
             return CacheLedger.load(cache_dir, metadata)
         except FileNotFoundError:
@@ -531,7 +523,6 @@ def _serialize_and_commit(self, cache_dir):
 @dataclass_json
 @dataclass(frozen=True)
 class CacheMetadata:
-    options: CacheOptions = CacheOptions.default()
     preprocessor_metadata: Optional[dict[str, Any]] = None
 
     def compare_to(self, other: "CacheMetadata") -> deepdiff.DeepDiff:
@@ -552,13 +543,6 @@ def empty():
         return CacheMetadata()
 
 
-@dataclass
-class _ShardStatus:
-    shard_name: str
-    num_rows_committed: int
-    is_finished: bool
-
-
 class SerialCacheWriter(AbstractContextManager):
     """
     Writes TreeCache-compatible caches to disk. This is a serial version of TreeCacheWriter that doesn't use Ray.
@@ -616,91 +600,6 @@ def write_batch(self, batch: BatchResult):
         self._tree_store.extend(cbatch)
 
 
-class ShardedCacheWriter:
-    """
-    Similar to SerialCacheWriter, but tracks shard metadata.
-
-    Similar to _OrderedCacheWriter, it also supports resuming, and it
-    groups together batches before writing (at some interval) in order to improve performance.
-    """
-
-    def __init__(
-        self,
-        cache_dir: str,
-        initial_ledger: CacheLedger,
-        exemplar: T,
-        on_write: Optional[Callable[[CacheLedger], None]] = None,
-    ):
-        self.cache_dir = cache_dir
-        self._on_write = on_write
-
-        self._ledger = copy.deepcopy(initial_ledger)
-
-        self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a")  # type: ignore
-        self._tree_store.trim_to_size(self._ledger.total_num_rows)
-
-    @property
-    def ledger(self):
-        return self._ledger
-
-    # we have both versions b/c we need this one for actors
-    def get_ledger(self):
-        return self._ledger
-
-    @property
-    def is_finished(self):
-        return self._ledger.is_finished
-
-    def finish_shard(self, shard_name: str, num_rows: int):
-        current_rows = self._ledger.shard_rows.get(shard_name, 0)
-        if current_rows != num_rows:
-            raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}")
-
-        self._ledger.finished_shards.append(shard_name)
-        self._ledger._serialize_and_commit(self.cache_dir)
-
-    def write_prepared_batch(self, shard_counts: Mapping[str, int], batch: PyTree[PreparedBatch]):
-        if self.is_finished:
-            raise RuntimeError("Cannot write to a finished cache")
-        self._tree_store.extend_with_batch(batch)
-
-        for shard, num_rows in shard_counts.items():
-            self._ledger.shard_rows[shard] = self._ledger.shard_rows.get(shard, 0) + num_rows
-
-        total_rows = self._ledger.total_num_rows + sum(shard_counts.values())
-        self._ledger.total_num_rows = total_rows
-        self._ledger._serialize_and_commit(self.cache_dir)
-
-        if self._on_write:
-            self._on_write(self._ledger)
-
-    def write_batch(self, shard_name: str, batch: BatchResult):
-        if self.is_finished:
-            raise RuntimeError("Cannot write to a finished cache")
-
-        if isinstance(batch, pa.RecordBatch):
-            raise NotImplementedError("Only non-RecordBatch batches are supported for now")
-
-        batch = _canonicalize_batch(batch)  # type: ignore
-        prepared = self._tree_store.batch_preparer(batch)
-
-        return self.write_prepared_batch({shard_name: len(batch)}, prepared)
-
-    def finish(self):
-        # if successful, write the ledger
-        logger.info("Finished writing cache")
-        # check that all shards are finished
-        if set(self._ledger.shard_rows.keys()) != set(self._ledger.finished_shards):
-            raise ValueError("Not all shards are finished")
-
-        self._ledger.is_finished = True
-        self._ledger._serialize_and_commit(self.cache_dir)
-        if self._on_write:
-            self._on_write(self._ledger)
-
-        return self._tree_store
-
-
 def _serialize_json_and_commit(path, obj):
     # just to be paranoid, we write to a temp file and then rename it
     # TODO: probably we could do better here
@@ -711,11 +610,10 @@ def _serialize_json_and_commit(path, obj):
         fs.copy(path, f"{path}.bak")
 
     for i in range(10):
-        with fsspec.open(f"{path}.tmp", "w") as file:
-            file.write(obj.to_json())
 
         try:
-            fs.rename(f"{path}.tmp", path)
+            with fsspec.open(path, "w") as file:
+                file.write(obj.to_json())
             break
         except FileNotFoundError:
             # this happens for some reason sometimes. It makes no sense.
@@ -724,7 +622,9 @@ def _serialize_json_and_commit(path, obj):
             pass
 
 
-@ray.remote(num_cpus=0.1)  # keep this small b/c it doesn't do a lot
+@ray.remote(
+    num_cpus=0.1, runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"})
+)  # keep this small b/c it doesn't do a lot
 class _TreeStoreCacheBuilder(SnitchRecipient):
     """
     Actor that coordinates the building of a cache. It spins up a bunch of workers to read from each shard
@@ -736,11 +636,9 @@ def __init__(
         self,
         cache_dir: str,
         name: str,
-        split: str,  # to workaround https://github.com/ray-project/ray/issues/44083
         source: ShardedDataSource[T],
         processor: BatchProcessor[T, U],
         options: CacheOptions,
-        force_flush: bool,
     ):
         pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT)
         self.logger = pylogging.getLogger(f"{__name__}.{name}")
@@ -751,7 +649,7 @@ def __init__(
             self._options = options
             self._updated_ledger_condition = asyncio.Condition()  # used to subscribe to metrics updates
 
-            self._ledger = CacheLedger.load_or_initialize(cache_dir, source, processor, options)
+            self._ledger = CacheLedger.load_or_initialize(cache_dir, source, processor)
 
             if self._ledger.is_finished:
                 self._finished_promise.set_result(None)
@@ -770,7 +668,16 @@ def __init__(
                 # (we get twice from we need to concatenate prepared batches into the accumulator)
                 # TODO: measure.
                 memory=2 * self._options.target_bytes_per_flush,
-            ).remote(current_actor_handle(), cache_dir, self._ledger, source, processor, force_flush)
+            ).remote(current_actor_handle(), cache_dir, source, options, processor)
+
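+            # progress bars: one tracks shard tokenization, the other tracks copying finished groups into this cache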
+            self._tokenize_pbar = tqdm(
+                total=len(source.shard_names), desc=f"{path_for_name}: tokenizing", unit="shard"
+            )
+            self._copy_pbar = tqdm(total=len(source.shard_names), desc=f"{path_for_name}: copying", unit="shard")
+            self._report_totals = _ProgressReport(0, 0, 0)
+            self._copy_report_totals = _ProgressReport(0, 0, 0)
+            self._last_update = time.time()
+
         except Exception:
             # Ray behaves poorly if the constructor of an actor fails, so we catch and log here
             # this also propagates to the finished promise, so we can handle it there
@@ -827,16 +734,26 @@ def _writer_exception(self, shard_name, exc_info: ExceptionInfo):
             pass
         self._do_notify()
 
+    def _child_failed(self, child: ray.actor.ActorHandle | str | None, exception: ExceptionInfo):
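+        # report a failed child task/actor the same way as a writer exception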
+        self._writer_exception(str(child), exception)
+
     def _notify_updated_ledger(self, ledger: CacheLedger):
         """
         Called by the cache writer when it has updated the ledger.
         """
         was_finished = self._ledger.is_finished
-        self._ledger = ledger
+        # ensure the ledger is "monotonic" meaning that we only expect it to grow
+        if ledger.total_num_rows < self._ledger.total_num_rows:
+            raise RuntimeError(f"Ledger went backwards: {ledger.total_num_rows} < {self._ledger.total_num_rows}")
+
+        for shard, rows in ledger.shard_rows.items():
+            if rows < self._ledger.shard_rows.get(shard, 0):
+                raise RuntimeError(f"Shard {shard} went backwards: {rows} < {self._ledger.shard_rows.get(shard, 0)}")
 
         if was_finished:
             raise RuntimeError("Ledger was already finished")
 
+        self._ledger = ledger
         if self._ledger.is_finished:
             logger.info(f"Finalizing cache {self._cache_dir}...")
             # guard against invalid state errors
@@ -854,40 +771,62 @@ async def _do_notify_async():
 
         asyncio.create_task(_do_notify_async())
 
+    def _report_progress(self, report: "_ProgressReport"):
+        import humanfriendly
+
+        if report.new_shards > 0:
+            self._tokenize_pbar.update(report.new_shards)
+        self._report_totals.new_shards += report.new_shards
+        self._report_totals.new_rows += report.new_rows
+        self._report_totals.new_bytes += report.new_bytes
+
+        if time.time() - self._last_update > 10.0:
+            self._last_update = time.time()
+
+            mb_str = humanfriendly.format_size(self._report_totals.new_bytes)
+            self._tokenize_pbar.set_postfix(
+                {
+                    "rows": self._report_totals.new_rows,
+                    "shards": self._report_totals.new_shards,
+                    "size": mb_str,
+                }
+            )
+
+    def _report_copy_progress(self, report: "_ProgressReport"):
+        self._copy_pbar.update(report.new_shards)
+        self._copy_report_totals.new_shards += report.new_shards
+        self._copy_report_totals.new_rows += report.new_rows
+        self._copy_report_totals.new_bytes += report.new_bytes
+
+        if time.time() - self._last_update > 10.0:
+            self._last_update = time.time()
+            self._copy_pbar.set_postfix(
+                {
+                    "shards": report.new_shards,
+                    "rows": report.new_rows,
+                    # "size": humanfriendly.format_size(report.new_bytes),
+                }
+            )
+
 
-def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default(), force_flush=False):
-    name = f"lev_cache_manager::{split}::{cache_dir}"
+def _get_builder_actor(cache_dir, shard_source, processor, options=CacheOptions.default()):
+    name = f"lev_cache_manager::{cache_dir}"
     path_for_name = os.path.join(*os.path.split(cache_dir)[-2:])
     name_for_display = f"builder::{path_for_name}"
 
     return _TreeStoreCacheBuilder.options(name=name, get_if_exists=True).remote(  # type: ignore
         name=name_for_display,
-        split=split,
         cache_dir=cache_dir,
         source=shard_source,
         processor=processor,
         options=options,
-        force_flush=force_flush,
     )
 
 
 #####
 # Core implementation starts below.
 #####
-# The main idea is to have a bunch of reader tasks that read batches, dispatch tokenization tasks, producing
-# a stream of tokenized batches. We then interleave these tokenized batches and write them to the cache.
-# The reader tasks are given a group of shards, which are implicitly concatenated together.
-
-
-@dataclass
-class _Batch:
-    """
-    A batch of data that has either been read or tokenized.
-    """
-
-    shard_name: str
-    row_indices: List[int]
-    payload: ray.ObjectRef
+# The main idea is to tokenize each shard group in parallel, and then write the results to the cache in order.
 
 
 @dataclass
@@ -898,33 +837,23 @@ class _ShardFinished:
 
     shard_name: str
     total_rows: int
+    path_to_shard: str
 
 
-_Message = _Batch | _ShardFinished
-"""
-A message that can be sent from a reader task to the writer task.
-"""
-
-_TIME_BETWEEN_WRITES = 20.0  # seconds
-
-
-@ray.remote(num_cpus=1)
+@ray.remote(num_cpus=1, runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"}))
 def _core_writer_task(
     parent,
     cache_dir,
-    initial_ledger: CacheLedger,
     source: ShardedDataSource,
+    options: CacheOptions,
     processor,
-    force_flush: bool,
 ):
     """
     This is the main task that processes the data and writes it to the cache.
 
-    It chains together:
-        * 1 generator per shard group
-        * interleaving of the generators
-        * processing of the batches
-        * writing of the batches to the cache
+    It receives "finished shards" messages from the reader tasks, and copies the data from temporary files
+    to the cache directory.
+
     """
     pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT)
     logger.info("Starting writer task")
@@ -933,400 +862,620 @@ def _core_writer_task(
     # append a small random number to the name to avoid collisions
     name += f"::{random.randint(0, 1000)}"
 
+    # we want to do the following:
+    # 1. write the 0th shard group to the output cache directly, updating metrics as we go
+    # 2. in the background, start processing other shard groups to temporary caches
+    # 3. once (1) is done, we start copying the temporary caches to the output cache (in order)
+
+    # We notify the parent actor of progress and updates to the ledger.
+    # We special-case the 0'th ledger because we commit it to the output cache directly.
+    def report_fn(report: _ProgressReport, ledger: CacheLedger):
+        parent._report_progress.remote(report)
+
+    def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
+        parent._report_progress.remote(report)
+        ray.get(parent._notify_updated_ledger.remote(ledger))
+
     with log_failures_to(parent):
+        temporary_cache_path = os.path.join(cache_dir, "___temp")
 
-        def on_write(ledger):
+        group_cache_paths: dict[str, str] = {}
+        group_ledgers: dict[str, CacheLedger | None] = {}
+        write_refs: dict[str, ray.ObjectRef] = {}
+
+        if len(source.shard_names) == 0:
+            logger.info("No shards to process. Writing empty ledger.")
+            ledger = CacheLedger.load_or_initialize(cache_dir, source, processor)
+            ledger.is_finished = True
+            ledger._serialize_and_commit(cache_dir)
             ray.get(parent._notify_updated_ledger.remote(ledger))
+            return
+
+        shard_groups = _assign_shards_to_groups(source, options.num_shard_groups)
+
+        for group_name, group in shard_groups.items():
+            assert len(group) > 0
 
-        sharded_cache_writer = ShardedCacheWriter(
-            cache_dir, initial_ledger, processor.output_exemplar, on_write=on_write
+        logger.debug(
+            f"Tokenizing {len(source.shard_names)} shards in {len(shard_groups)} groups to {temporary_cache_path}."
         )
 
-        options = initial_ledger.metadata.options
-        num_groups = min(options.num_shard_groups or 1000000, len(source.shard_names))
+        processor_ref = ray.put(processor)
+        source_ref = ray.put(source)
 
-        processor_pool = _mk_processor_pool(processor, 0, num_groups * 4)
+        # We treat the first group specially: we tokenize it directly to the output cache (since it comes first)
+        # This enables us to expose data quickly
+        first_group = next(iter(shard_groups), None)
 
-        interleave: RayPrefetchQueue = RayPrefetchQueue(
-            lambda: _make_interleave(name, source, initial_ledger, processor_pool),
-            64,
-            producer_options={"num_cpus": 1, "name": f"{name}::interleave"},
+        for group_name, shards in shard_groups.items():
+            if group_name == first_group:
+                group_out_path = cache_dir
+            else:
+                group_out_path = os.path.join(temporary_cache_path, group_name)
+
+            group_cache_paths[group_name] = group_out_path
+
+            ledger = _try_load(group_out_path)
+            group_ledgers[group_name] = ledger
+
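+            # a non-None ledger means this group was fully tokenized on a previous run, so we can skip it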
+            if ledger is not None:
+                if group_name == first_group:
+                    ray.get(parent._notify_updated_ledger.remote(ledger))
+                continue
+
+            report_fn_to_use = report_fn_first_group if group_name == first_group else report_fn
+
+            ref = (
+                ray.remote(_tokenize_one_shard_group)
+                .options(  # type: ignore
+                    num_cpus=processor.num_cpus,
+                    num_gpus=processor.num_gpus,
+                    resources=processor.resources,
+                    memory=3 * 1024 * 1024 * 1024,  # made this up
+                    name=f"tokenize::{temporary_cache_path}::{group_name}",
+                    retry_exceptions=True,
+                    max_retries=10,
+                )
+                .remote(group_out_path, source_ref, shards, processor_ref, options, report_fn_to_use, parent)
+            )
+
+            write_refs[group_name] = ref
+
+        ledger = _start_copies(
+            parent,
+            cache_dir,
+            shard_groups,
+            first_group,
+            write_refs,
+            group_ledgers,
+            group_cache_paths,
+            processor,
+            processor_ref,
         )
 
-        total_time = Stopwatch()
-        loading_time = Stopwatch()
-        append_time = Stopwatch()
-        flush_time = Stopwatch()
-        flush_amortized_time = Stopwatch()
-
-        current_prepared_batch: Optional[PyTree[PreparedBatch]] = None
-        current_shard_rows: dict[str, int] = {}
-        time_of_last_write = time.time()
-        batches_total = 0.0
-        flush_thread = None
-        finished_shards_last_flush: list = []
-
-        while True:
-            with total_time:  # 0.0051
-                try:
-                    cur_time = time.time()
-                    time_since_last_write = cur_time - time_of_last_write
-                    remaining_time = _TIME_BETWEEN_WRITES - time_since_last_write
-
-                    if current_prepared_batch is not None:
-                        with flush_amortized_time:  # 6e-4
-                            current_byte_size = sum(
-                                b.byte_size for b in jax.tree_util.tree_flatten(current_prepared_batch)[0]
-                            )
-                            should_flush = (
-                                force_flush
-                                or remaining_time <= 0
-                                or (current_byte_size >= options.target_bytes_per_flush)
-                            )
-                            if should_flush:
-                                with flush_time:  # 0.613s
-                                    if flush_thread is not None:
-                                        flush_thread.join()
-
-                                    flush_thread = ExceptionTrackingThread(
-                                        target=_write_batches,
-                                        args=(
-                                            sharded_cache_writer,
-                                            current_shard_rows,
-                                            current_prepared_batch,
-                                            finished_shards_last_flush,
-                                        ),
-                                    )
-                                    flush_thread.start()
-
-                                    current_prepared_batch = None
-                                    current_shard_rows = {}
-                                    finished_shards_last_flush = []
-
-                                    time_of_last_write = time.time()
-                                    continue
-                    else:
-                        remaining_time = _TIME_BETWEEN_WRITES
-
-                    with loading_time:
-                        try:
-                            message = interleave.get_next(timeout=max(remaining_time, 0.1))
-                        except QueueEmpty:
-                            logger.info("Writer running ahead of reader.")
-                            continue
-
-                    with append_time:
-                        match message:
-                            case _Batch(shard, row_indices, payload):
-                                batches_total += 1
-                                this_prepared_batch = ray.get(payload)
-                                if current_prepared_batch is None:
-                                    # TODO: actually check row indices
-                                    current_shard_rows = {shard: len(row_indices)}
-                                    current_prepared_batch = this_prepared_batch
-                                else:
-                                    current_shard_rows[shard] = current_shard_rows.get(shard, 0) + len(row_indices)
-                                    current_prepared_batch = _concat_prepared_batches(
-                                        current_prepared_batch, this_prepared_batch
-                                    )
-                                    del this_prepared_batch
-
-                                if force_flush:
-                                    _write_batches(
-                                        sharded_cache_writer,
-                                        current_shard_rows,
-                                        current_prepared_batch,
-                                        finished_shards_last_flush,
-                                    )
-                                    finished_shards_last_flush = []
-                                    current_prepared_batch = None
-                                    current_shard_rows = {}
-
-                            case _ShardFinished(shard, total_rows):
-                                finished_shards_last_flush.append((shard, total_rows))
-                            case _:
-                                raise AssertionError(f"Unexpected message type {type(message)}")
-
-                    # if batches_total % 1000 == 0:
-                    #     print(
-                    #         f"Processed {batches_total} batches: {loading_time.average()}s load,"
-                    #         f" {append_time.average()}s append, {flush_time.average()}s flush blocked, "
-                    #         f"{flush_amortized_time.average()}s amortized flush, "
-                    #         f"{total_time.average()}s total"
-                    #     )
-                except StopIteration:
-                    logger.info("Finished all shards")
-                    break
-                except Exception as e:
-                    logger.exception("Error while processing batch")
-                    raise e
+        ledger.is_finished = True
+        ledger._serialize_and_commit(cache_dir)
+        ray.get(parent._notify_updated_ledger.remote(ledger))
+
+        temporary_cache_paths = set(group_cache_paths.values()) - {cache_dir}
+        _clean_up_temp_caches(temporary_cache_paths)
+
+
+def _start_copies(
+    parent,
+    cache_dir,
+    shard_groups,
+    first_group,
+    write_refs,
+    group_ledgers,
+    group_cache_paths,
+    processor,
+    processor_ref,
+):
+    """
+    Copy the temporary caches to the output cache, in order. (essentially concatenating them)
 
-        # force a flush
-        if current_prepared_batch is not None or finished_shards_last_flush:
-            if flush_thread is not None:
-                flush_thread.join()
-            _write_batches(
-                sharded_cache_writer, current_shard_rows, current_prepared_batch, finished_shards_last_flush
+    Args:
+        parent: the parent actor handle (_TreeStoreCacheBuilder)
+        cache_dir: the output cache directory
+        shard_groups: a dict mapping group names to lists of shard names
+        first_group: the privileged group that is written directly to the output cache
+        write_refs: a dict mapping group names to ray.ObjectRefs of the cache building tasks
+        group_ledgers: a dict mapping group names to the ledgers for the groups. Mutated in place.
+        group_cache_paths: a dict mapping group names to the paths of the temporary caches
+        processor: the processor object
+        processor_ref: a ray.ObjectRef of the processor object
+
+    Returns:
+        The final ledger
+    """
+    # This logic is a bit hairy thanks to resumes.
+    # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these
+    # separately. We also need to update the ledger as we go.
+    # Second, note that JaggedArrayStores have two notions of length: the number of rows, and the data size.
+    # We store the number of rows in offsets[0], and the data size in offsets[offsets[0]], which is just the final offset.
+    # So we can keep a cache "locked" to a particular read size until we're ready by controlling the offsets.
+
+    # * When we load the permanent cache, we have already written some number of groups to it. In
+    #   particular, we have written the 0'th group to the permanent cache.
+    # * We enforce that we only commit a whole group to the ledger at a time.
+    # * We need to copy the remaining groups to the permanent cache, and update the ledger as we go.
+    # * To copy a group, we need to know the total number of rows in that group, as well as the "data offsets"
+    #   for the data in the cache. We can get the total number of rows from the ledger, and we also calculate
+    #   the data offsets for where the group goes in the permanent cache. This is just a running sum of the
+    #   data sizes of the previous groups. Because we have multiple JaggedArrayStores, this can be a pytree
+    #   of integers, one for each array.
+    # * Once we have finished the i'th cache and all caches < i, we can "unlock" the data for the i'th cache
+    #   by updating the offset[0] of the permanent cache to the total number of rows through the i'th cache.
+    # * We also need to update the ledger with the total number of rows
+
+    # reload the ledger for the first group, which will be the sink for the other groups
+    assert first_group in write_refs
+
+    group_ledgers[first_group] = ray.get(write_refs[first_group])
+    overall_ledger = group_ledgers[first_group]
+
+    # initialize the data offset tree
+    permanent_cache = TreeStore.open(processor.output_exemplar, cache_dir, mode="a", cache_metadata=False)
+    data_offset_tree = jax.tree.map(lambda x: x.data_size, permanent_cache.tree)
+    total_rows_from_caches = overall_ledger.total_num_rows
+    copy_refs: dict[str, ray.ObjectRef] = {}
+    last_ref: ray.ObjectRef | None = None
+
+    found_one_to_copy = False
+
+    for group in shard_groups:
+        # first make sure it's either done this run or already done
+        if write_refs.get(group) is not None:
+            this_ledger = ray.get(write_refs[group])
+            group_ledgers[group] = this_ledger
+        else:
+            this_ledger = group_ledgers[group]
+
+        if group == first_group:
+            # this is the first group, so it's already in the cache and we don't need to
+            # increment the data offset tree etc.
+            parent._report_copy_progress.remote(
+                _ProgressReport(new_shards=len(overall_ledger.finished_shards), new_rows=overall_ledger.total_num_rows)
             )
+            continue
 
-        sharded_cache_writer.finish()
+        assert this_ledger is not None
+        # see if we already copied this group, meaning all the shards are in the permanent cache
+        shards_copied = sum(1 if shard in overall_ledger.finished_shards else 0 for shard in shard_groups[group])
+
+        if found_one_to_copy and shards_copied > 0:
+            raise RuntimeError("A previous group was copied, but this group was not. This should never happen.")
+        elif shards_copied == len(shard_groups[group]):
+            assert (
+                overall_ledger.total_num_rows >= total_rows_from_caches
+            ), f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}"
+            continue  # nothing to do
+        elif shards_copied > 0:
+            # In theory we can handle this, but it's a bit tricky, so we're going to punt for now
+            raise RuntimeError("Some shards were copied but not all. This should never happen.")
+
+        found_one_to_copy = True
+        # we need to copy this group
+
+        # we can't "commit" the group to the ledger (or the number of rows)
+        # until we've updated the ledger for all previous groups, so we block on the last ref
+        ref_to_send = None if last_ref is None else RefBox(last_ref)
+
+        last_ref = _copy_cache.remote(
+            cache_dir,
+            group_cache_paths[group],
+            processor_ref,
+            data_offset_tree,
+            ref_to_send,
+            total_rows_from_caches,
+            parent,
+        )
+        copy_refs[group] = last_ref
 
-        out = sharded_cache_writer.get_ledger()
-        return out
+        # update the offset information: data offsets and total rows
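+        # (these come from the temporary cache's metadata, so we don't have to wait for the copy task itself)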
+        this_cache = TreeStore.open(processor.output_exemplar, group_cache_paths[group], mode="r", cache_metadata=True)
+        data_offset_tree = jax.tree.map(
+            operator.add, data_offset_tree, jax.tree.map(lambda x: x.data_size, this_cache.tree)
+        )
+        total_rows_from_caches += this_ledger.total_num_rows
 
+    # refs form a linked list implicitly, so we can just wait on the last one
+    if last_ref is not None:
+        ledger = ray.get(last_ref)
+    else:
+        ledger = overall_ledger
+    return ledger
 
-def _concat_prepared_batches(
-    current_prepared_batch: PyTree[PreparedBatch], this_prepared_batch: PyTree[PreparedBatch]
-):
-    return jax.tree.map(lambda *bs: PreparedBatch.concat(bs), current_prepared_batch, this_prepared_batch)
+
+def _clean_up_temp_caches(paths):
+    for path in paths:
+        if fsspec_exists(path):
+            for i in range(10):
+                # this is crashy for some reason
+                try:
+                    fsspec_remove(path, recursive=True)
+                    break
+                except Exception:
+                    logger.exception(f"Failed to remove {path} on attempt {i}")
+                    time.sleep(1)
 
 
-def _write_batches(writer: ShardedCacheWriter, shard_totals, batch: Optional[PyTree[PreparedBatch]], finished_shards):
-    # concatenate the payloads
-    if batch is not None:
-        writer.write_prepared_batch(shard_totals, batch)
+def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]:
+    if num_groups is None or num_groups >= len(source.shard_names):
+        return {shard_name: [shard_name] for shard_name in source.shard_names}
 
-    for shard, total_rows in finished_shards:
-        writer.finish_shard(shard, total_rows)
+    shard_names = source.shard_names
+    num_shards_per_group = (len(shard_names)) // num_groups
+    num_groups_with_extra = len(shard_names) % num_groups
 
+    # if we have a remainder, we want to distribute the extra shards evenly
+    out_groups: dict[str, list[str]] = {}
+    start = 0
+    for i in range(num_groups):
+        num_shards = num_shards_per_group + (1 if i < num_groups_with_extra else 0)
+        out_groups[f"group_{i}"] = list(shard_names[start : start + num_shards])
+        start += num_shards
 
-def _fetch_batches(batches) -> tuple[dict[str, int], list[PreparedBatch]]:
-    shards_for_batches, payloads_for_batches = zip(*batches)
-    payloads_for_batches = ray.get(list(payloads_for_batches))
+    # make sure we got all the shards
+    assert sum(len(shards) for shards in out_groups.values()) == len(shard_names)
 
-    shard_row_totals: dict[str, int] = {}
-    for shard, payload in zip(shards_for_batches, payloads_for_batches):
-        shard_row_totals[shard] = shard_row_totals.get(shard, 0) + jax.tree.leaves(payload)[0].num_rows
+    return out_groups  # type: ignore
 
-    return shard_row_totals, payloads_for_batches
 
+def _merge_ledgers(dest, source):
+    dest.total_num_rows += source.total_num_rows
+    for shard, rows in source.shard_rows.items():
+        current_value = dest.shard_rows.get(shard, 0)
+        assert current_value == 0, f"Shard {shard} already has {current_value} rows"
+        dest.shard_rows[shard] = rows
 
-def _interleave_shards(readers: Sequence[RayPrefetchQueue], first_index: int) -> Iterator[T]:  # _Message
+    dest.finished_shards.extend(source.finished_shards)
+    for field, count in source.field_counts.items():
+        dest.field_counts[field] = dest.field_counts.get(field, 0) + count
+
+    return dest
+
+
+@ray.remote(num_cpus=4, memory=4 * 1024 * 1024 * 1024)
+def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: RefBox | None, rows_so_far, parent):
     """
-    Interleaves the results of multiple iterators. To support resume,
-    we need to be able to start from not the "first" iterator.
+    Copies the data from one cache to another, appending it to the end of the destination cache.
 
+    Once the copy is done and the last_ref is set, the data is "unlocked" in the destination cache by updating the
+    offsets[0] of the destination cache to the total number of rows in the cache.
+
     Args:
-        readers: A list of iterators
-        first_index: The index of the first iterator to start from. We use this to support resuming.
-    """
+        dest_path: The path to the destination cache.
+        source_path: The path to the source cache.
+        processor: The processor used to create the cache.
+        data_offset_tree: The data offset tree for the destination cache.
+        last_ref: The ref to wait on before updating the ledger.
+        rows_so_far: The total number of rows in the destination cache before this copy.
+        parent: The parent actor handle to notify of ledger updates and copy progress.
 
-    finished: set[int] = set()
-    total = 0
-    while len(finished) < len(readers):
-        for i in range(first_index, len(readers)):
-            reader = readers[i]
-            if i not in finished:
-                try:
-                    message = reader.get_next()
-                    total += 1
-                    yield message
-                except StopIteration:
-                    finished.add(i)
-                except Exception as e:
-                    logger.exception(f"Error while processing group {i}")
-                    raise e
+    Returns:
+        The destination cache's ledger after the source ledger has been merged into it.
 
-        first_index = 0
+    """
+    with log_failures_to(parent):
+        asyncio.run(_extend_cache_with_other_cache(dest_path, source_path, processor, data_offset_tree, rows_so_far))
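+        # the data copy above runs immediately; before committing the ledger we wait for the previous
+        # group's copy so that ledger updates land in shard-group order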
+        if last_ref is not None:
+            ray.wait([last_ref.ref], fetch_local=False)
+        permanent_cache = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False)
+        source_ledger = CacheLedger.load(source_path)
+
+        new_num_rows = source_ledger.total_num_rows + rows_so_far
+
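+        # "unlock" the copied rows for readers by bumping offsets[0] (the row count) on every leaf array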
+        futures = jax.tree.leaves(jax.tree.map(lambda x: x.offsets[0].write(new_num_rows), permanent_cache.tree))
+        for future in futures:
+            future.result()
+
+        dest_ledger = CacheLedger.load(dest_path)
+        _merge_ledgers(dest_ledger, source_ledger)
+        dest_ledger._serialize_and_commit(dest_path)
+        assert not dest_ledger.is_finished
+
+        ray.get(parent._notify_updated_ledger.remote(dest_ledger))
+        parent._report_copy_progress.remote(
+            _ProgressReport(new_shards=len(source_ledger.shard_rows), new_rows=source_ledger.total_num_rows)
+        )
 
-    logger.info(f"Finished all shards, got {total} batches")
+        return dest_ledger
 
 
-def _assign_shards_to_groups(shards: Sequence[_ShardStatus], num_groups: int) -> list["_ShardGroup"]:
+async def _extend_cache_with_other_cache(
+    dest_path: str, source_path: str, processor: BatchProcessor, data_offset_tree: PyTree[int], row_offset
+) -> int:
     """
-    Assigns shards to groups in a round-robin fashion.
+    Copies the data from one cache to another, appending it to the end of the destination cache.
+
+    Returns:
+        The number of rows in the source cache.
     """
-    groups: list[list] = [[] for _ in range(num_groups)]
-    for i, shard in enumerate(shards):
-        groups[i % num_groups].append(shard)
-    return [_ShardGroup(group) for group in groups]
+    logger.info(f"Copying data from {source_path} to {dest_path}.")
+    dest = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False)
+    source = TreeStore.open(processor.output_exemplar, source_path, mode="r", cache_metadata=True)
+
+    source_num_rows = await source.async_len()
+
+    async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int):
+        """Copies **just the data array** from one shard to the permanent cache at a given offset."""
+        # TODO: it'd be good if we just didn't expose the full data array (but only the used part)
+        data_size = source_array.data_size
+        data = source_array.data[0:data_size]
+        futures: list[ts.Future] = []
+
+        # write_future = dest_array.data[data_offset : data_offset + source_array.data_size].write(data)
+        async with ts.Transaction() as txn:
+            dest = dest_array.data
+            out_end = data_offset + data_size
+            write_future = dest.with_transaction(txn)[data_offset:out_end].write(data)
+            futures.append(write_future)
+
+        if source_array.shapes is not None:
+            source_shapes = source_array.shapes[0:source_num_rows]
+            async with ts.Transaction() as txn:
+                dest = dest_array.shapes
+                out_end = row_offset + source_num_rows
+                shape_future = dest.with_transaction(txn)[row_offset:out_end].write(source_shapes)
+                futures.append(shape_future)
+
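+        # rebase the offsets: source offsets are relative to the source data array, so shift them by
+        # data_offset (via a virtual view) before writing them into the destination's offset range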
+        source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]]
+        source_offsets = _virtual_offset(source_offsets, data_offset)
+
+        async with ts.Transaction() as txn:
+            dest = dest_array.offsets
+            out_end = row_offset + 1 + source_num_rows
+            offset_future = dest.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets)
+
+        futures.append(offset_future)
+
+        out = await asyncio.gather(*futures)
+        return out
 
+    futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree)
 
-def _randomize_shards(shards: Sequence[T], seed: int) -> list[T]:
-    prng = random.Random(seed)
-    shuffled = list(shards)
-    prng.shuffle(shuffled)
-    return shuffled
+    await asyncio.gather(*jax.tree.leaves(futures))
+    logger.info(f"Finished copying data from {source_path} to {dest_path}.")
 
+    return source_num_rows
 
-class _ShardGroup:
-    """
-    Given a group of shards and a list of statuses, implicitly concatenates the shards and reads from them.
 
-    This class mostly exists for resuming: we want to be able to start from the last shard we were working on.
+def _virtual_offset(base: ts.TensorStore, offset_amount):
+    """
+    This function creates a new tensorstore that is a virtual offset of another tensorstore.
+    That is, it's y[i] = x[i] + offset_amount.
     """
 
-    def __init__(self, group: list[_ShardStatus]):
-        self.shards = group
-        self.total_rows_committed, _all_finished = self._impute_total_rows_committed_and_check_invariants()
-
-    def _impute_total_rows_committed_and_check_invariants(self):
-        # we also want to ensure that we haven't started any shards until we've finished the previous ones
-        total_committed = 0
-        last_shard_name = None
-        last_was_finished = True
-        all_finished = True
-
-        for status in self.shards:
-            shard_name = status.shard_name
-            if not last_was_finished and status.num_rows_committed > 0:
-                raise ValueError(
-                    f"Shard {shard_name} has rows committed but previous shard in group {last_shard_name} "
-                    "is not finished. Something about the cache configuration has changed: either the "
-                    "number/order of shards, the shard shuffle random seed, or the number of groups."
-                )
-            total_committed += status.num_rows_committed
-            if not status.is_finished:
-                all_finished = False
-            last_was_finished = status.is_finished
-            last_shard_name = shard_name
+    async def do_read(domain: ts.IndexDomain, array: np.ndarray, read_params: ts.VirtualChunkedReadParameters):
+        array[...] = (await base[domain].read()) + offset_amount
 
-        return total_committed, all_finished
+    return ts.virtual_chunked(do_read, dtype=base.dtype, domain=base.domain, shape=base.shape)
 
 
-def _make_interleave(name: str, source: ShardedDataSource, initial_ledger: CacheLedger, processor_pool: ActorHandle):
-    """
-    Given a list of ShardStatus objects and sources, creates an interleaving generator
-    that reads from shards and tokenizes them in parallel.
+async def _copy_data_from_one_shard_to_permanent_memory(
+    dest_path: str,
+    source_path: str,
+    processor: BatchProcessor,
+    data_offset_tree: PyTree[int],
+):
+    """Copies from one tree store to the permanent cache at a given offset (for each leaf)"""
+    logger.info(f"Copying data from {source_path} to {dest_path}.")
+    dest = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False)
+    source = TreeStore.open(processor.output_exemplar, source_path, mode="r", cache_metadata=True)
 
-    We use ShardStatus objects to track the progress of each shard. If we're preempted, we can resume
-    from the last shard we were working on. This function starts each shard at the last committed row
-    and starts interleaving from the next shard (i.e. the one with the fewest rows that isn't finished).
-    """
-    logger.setLevel(DEFAULT_LOG_LEVEL)
-    statuses = _get_shard_statuses(initial_ledger, source)
+    def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int):
+        # TODO: it'd be good if we just didn't expose the full data array (but only the used part)
+        data = source_array.data[0 : source_array.data_size]
+        # write_future = dest_array.data[data_offset : data_offset + source_array.data_size].write(data)
+        with ts.Transaction() as txn:
+            dest = dest_array.data
+            out_end = data_offset + source_array.data_size
+            write_future = dest.with_transaction(txn)[data_offset:out_end].write(data)
 
-    options = initial_ledger.metadata.options
+        return write_future
 
-    unfinished_shards = _check_current_shard_progress(statuses)
+    futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree)
 
-    if not unfinished_shards:
-        logger.info("All shards finished. Nothing to do.")
-        return
+    await asyncio.gather(*jax.tree.leaves(futures))
+    logger.info(f"Finished copying data from {source_path} to {dest_path}.")
+    return
 
-    group_names, groups = _randomize_and_group_shards(name, options, statuses)
 
-    logger.warning(f"Starting cache build with {len(statuses)} shards, in {len(groups)} groups")
+@dataclass
+class _ProgressReport:
+    new_rows: int = 0
+    new_bytes: float = 0
+    new_shards: int = 0
+    # TODO: other counts
 
-    def _make_generator_fn(group: _ShardGroup):
-        def generator():
-            pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT)
-            for message in _shard_reader_generator(source, group, options.batch_size):
-                match message:
-                    case _Batch():
-                        # processed = ray.put(process_task(ray.get(message.payload)))
-                        # processed = process_task.remote(processor_ref, message.payload)
-                        processed = processor_pool.process_batch.remote(RefBox(message.payload))
-                        yield dataclasses.replace(message, payload=processed)
-                    case _ShardFinished():
-                        yield message
-                    case _:
-                        raise AssertionError(f"Unexpected message type {type(message)}")
 
-        return generator
+def _tokenize_one_shard_group(
+    temporary_cache_path: str,
+    source: ShardedDataSource,
+    shards: list[str],
+    processor: BatchProcessor,
+    options: CacheOptions,
+    report_fn: Callable[[_ProgressReport, CacheLedger], None],
+    force_unfinalized: bool,
+) -> CacheLedger:
+    # ray breaks if this is top level
+    import humanfriendly
 
-    generator_fns = [_make_generator_fn(group) for group in groups]
+    logger = pylogging.getLogger("tokenize")
+    pylogging.basicConfig(level=pylogging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
-    readers = [
-        RayPrefetchQueue(
-            fn,
-            options.prefetch_per_group,
-            producer_options=dict(num_cpus=0.1, name=name, scheduling_strategy="SPREAD"),
-        )
-        for name, fn in zip(group_names, generator_fns)
-    ]
+    # restrict shards to the ones we're supposed to process
+    # this is a bit hacky but when there are a lot of shards (e.g. SlimPajama 122K),
+    # we encounter significant overhead just parsing the shard names from the json
+    source = _RestrictedShardedDataSource(source, shards)
 
-    # then figure out the first shard to start from. This is the first unfinished shard with the minimum number of rows
-    first_group_to_start = min(
-        range(len(groups)),
-        key=lambda i: groups[i].total_rows_committed,
-    )
+    ledger = CacheLedger.load_or_initialize(temporary_cache_path, source, processor)
 
-    yield from _interleave_shards(readers, first_group_to_start)
+    if ledger.is_finished:
+        logger.info("Shard group already processed.")
+        return ledger
 
+    writer = ShardGroupCacheWriter(temporary_cache_path, ledger, shards, processor.output_exemplar)
 
-def _mk_processor_pool(processor, min_size, max_size):
-    import hashlib
+    total_rows = ledger.total_num_rows
+    found_shard_with_rows = False
 
-    metadata_hash = hashlib.md5(str(processor.metadata).encode()).hexdigest()
-    processor_pool_name = f"processor_pool::{metadata_hash}"
-    processor_pool = BatchProcessorPool.options(  # type: ignore
-        name=processor_pool_name, get_if_exists=True, lifetime="detached"
-    ).remote(  # type: ignore
-        processor, min_size, max_size
-    )
+    if total_rows > 0:
+        report_fn(_ProgressReport(new_rows=total_rows), ledger)
 
-    ray.get(processor_pool.ensure_max_at_least.remote(max_size))
+    for shard_name in shards:
+        if shard_name in ledger.finished_shards:
+            logger.info(f"Shard {shard_name} already processed.")
+            report_fn(_ProgressReport(new_shards=1), ledger)
+            continue
 
-    return processor_pool
+        logger.debug(f"Processing {shard_name}.")
 
+        rows_this_shard = ledger.shard_rows.get(shard_name, 0)
 
-def _check_current_shard_progress(statuses):
-    unfinished_shards: list[_ShardStatus] = []
-    shards_with_progress: dict[str, int] = {}
-    for status in statuses:
-        if not status.is_finished:
-            unfinished_shards.append(status)
-        if status.num_rows_committed > 0:
-            shards_with_progress[status.shard_name] = status.num_rows_committed
-    if unfinished_shards and shards_with_progress:
-        formatted = ", ".join(f"{k}: {v}" for k, v in shards_with_progress.items())
-        logger.info(f"Resuming from shards with progress: {formatted}")
-    return unfinished_shards
+        if found_shard_with_rows and rows_this_shard != 0:
+            raise ValueError("Found more than one shard with rows to process.")
 
+        if rows_this_shard != 0:
+            found_shard_with_rows = True
 
-def _randomize_and_group_shards(name, options, statuses):
-    if options.shard_order_randomization_key is not None:
-        seed = options.shard_order_randomization_key
-        logger.info(f"Randomizing shard order with seed {seed}")
-        statuses = _randomize_shards(statuses, seed)
+        shard_iterator = source.open_shard_at_row(shard_name, rows_this_shard)
 
-    num_groups = min(
-        options.num_shard_groups if options.num_shard_groups is not None else len(statuses), len(statuses)
-    )
-    if num_groups == 1:
-        group_names = [f"generator::{name}::all_shards"]
-    elif len(statuses) == num_groups:
-        group_names = [f"generator::{name}::{status.shard_name}" for status in statuses]
-    else:
-        group_names = [f"generator::{name}::group_{i}" for i in range(num_groups)]
+        prepared_batch: PyTree[PreparedBatch] | None = None
+        this_batch_size = 0
+
+        for batch in batched(shard_iterator, options.batch_size):
+            tokenized = processor(batch)
+            tokenized = _canonicalize_batch(tokenized)  # type: ignore
+            this_prepared = writer._tree_store.batch_preparer(tokenized)
+
+            this_batch_size += len(batch)
+            rows_this_shard += len(batch)
+            total_rows += len(batch)
+
+            if prepared_batch is None:
+                prepared_batch = this_prepared
+            else:
+                prepared_batch = jax.tree.map(
+                    lambda *trees: PreparedBatch.concat(trees), prepared_batch, this_prepared
+                )
+
+            batch_byte_size = sum(leaf.byte_size for leaf in jax.tree.leaves(prepared_batch))
+
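+            # flush the accumulated rows once they exceed the target buffer size, then reset the accumulator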
+            if batch_byte_size > options.target_bytes_per_flush:
+                writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch)
+                report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger)
+
+                nice_bytes = humanfriendly.format_size(batch_byte_size)
+                logger.debug(
+                    f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})"
+                )
+                # print(f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})", flush=True)
+                this_batch_size = 0
+                prepared_batch = None
+
+        if prepared_batch is not None:
+            batch_byte_size = sum(leaf.byte_size for leaf in jax.tree.leaves(prepared_batch))
+            nice_bytes = humanfriendly.format_size(batch_byte_size)
+
+            report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger)
 
-    groups = _assign_shards_to_groups(statuses, num_groups)
-    return group_names, groups
+            writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch)
+            logger.debug(
+                f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})"
+            )
+            this_batch_size = 0
+            prepared_batch = None
+
+        writer.finish_shard(shard_name, rows_this_shard)
+
+        report_fn(_ProgressReport(new_shards=1), writer.ledger)
+
+    if not force_unfinalized:
+        writer.finish()
 
+    logger.debug(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.")
 
-def _shard_reader_generator(
-    shard_source: ShardedDataSource[T], group: _ShardGroup, batch_size: int
-) -> Iterator[_Message]:
+    return writer.ledger
+
+
+class ShardGroupCacheWriter:
     """
-    Given a group of shards, implicitly concatenates the shards and reads from them.
+    Similar to SerialCacheWriter, but tracks shard metadata for one group of shards.
     """
-    for status in group.shards:
-        if status.is_finished:
-            logger.info(f"Skipping finished shard {status.shard_name}")
-            continue
-        start_row = status.num_rows_committed
-        logger.info(f"Opening shard {status.shard_name} at row {start_row}")
-        shard_iter = shard_source.open_shard_at_row(status.shard_name, start_row)
 
-        batch = []
-        batch_idxes = []
-        row_idx = start_row
-        for row in shard_iter:
-            batch.append(row)
-            batch_idxes.append(row_idx)
-            row_idx += 1
+    def __init__(self, cache_dir: str, initial_ledger: CacheLedger, shards: list[str], exemplar: T):
+        self.cache_dir = cache_dir
+
+        self._ledger = copy.deepcopy(initial_ledger)
+        self.shards = shards
+
+        self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a")  # type: ignore
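+        # drop any rows beyond what the ledger has committed (e.g. from a partially flushed earlier run)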
+        self._tree_store.trim_to_size(self._ledger.total_num_rows)
+
+    @property
+    def ledger(self):
+        return self._ledger
+
+    # we have both versions b/c we need this one for actors
+    def get_ledger(self):
+        return self._ledger
 
-            if len(batch) == batch_size:
-                yield _Batch(status.shard_name, batch_idxes, ray.put(batch))
-                batch = []
-                batch_idxes = []
+    @property
+    def is_finished(self):
+        return self._ledger.is_finished
 
-        if len(batch) > 0:
-            yield _Batch(status.shard_name, batch_idxes, ray.put(batch))
+    def finish_shard(self, shard_name: str, num_rows: int):
+        if shard_name not in self.shards:
+            raise ValueError(f"Shard {shard_name} not in tracked shards")
 
-        logger.info(f"Finished generating shard {status.shard_name} with {row_idx} rows")
-        yield _ShardFinished(status.shard_name, row_idx)
+        current_rows = self._ledger.shard_rows.get(shard_name, 0)
+        if current_rows != num_rows:
+            raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}")
+
+        self._ledger.finished_shards.append(shard_name)
+        self._ledger._serialize_and_commit(self.cache_dir)
+
+    def write_prepared_batch(self, shard_name: str, row_count: int, batch: PyTree[PreparedBatch]):
+        if self.is_finished:
+            raise RuntimeError("Cannot write to a finished cache")
+        self._tree_store.extend_with_batch(batch)
+
+        if shard_name not in self.shards:
+            raise ValueError(f"Shard {shard_name} not in tracked shards")
+        self._ledger.shard_rows[shard_name] += row_count
+        self._ledger.total_num_rows += row_count
+
+        self._ledger._serialize_and_commit(self.cache_dir)
+
+    def finish(self):
+        # ensure all tracked shards are finished before marking the cache as finished
+        if len(self._ledger.finished_shards) != len(self.shards):
+            raise ValueError("Not all shards are finished")
+
+        self._ledger.is_finished = True
+        self._ledger._serialize_and_commit(self.cache_dir)
+
+        return self._tree_store
+
+
+class _RestrictedShardedDataSource(ShardedDataSource):
+    def __init__(self, source: ShardedDataSource, shards: list[str]):
+        self._source = source
+        self._shards = shards
+
+    @property
+    def shard_names(self):
+        return self._shards
+
+    def open_shard_at_row(self, shard_name, row):
+        return self._source.open_shard_at_row(shard_name, row)
+
+
+def _randomize_shards(shards: Sequence[T], seed: int) -> list[T]:
+    prng = random.Random(seed)
+    shuffled = list(shards)
+    prng.shuffle(shuffled)
+    return shuffled
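+
+
+# A note on the contract (illustrative): the shuffle is a pure function of the seed, so
+# repeated or resumed builds see the same shard order for the same seed, e.g.
+#
+#   names = ["shard_0", "shard_1", "shard_2", "shard_3"]
+#   assert _randomize_shards(names, seed=0) == _randomize_shards(names, seed=0)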
 
 
 def _canonicalize_batch(batch: Union[dict, List[dict]]) -> List[dict]:
@@ -1360,8 +1509,13 @@ def _ledger_to_metrics(ledger: CacheLedger) -> InProgressCacheMetrics:
     )
 
 
-def _get_shard_statuses(ledger: CacheLedger, source: ShardedDataSource):
-    return [
-        _ShardStatus(name, ledger.shard_rows.get(name, 0), name in ledger.finished_shards)
-        for name in source.shard_names
-    ]
+def _try_load(path):
+    try:
+        ledger = CacheLedger.load(path)
+        if ledger.is_finished:
+            return ledger
+        else:
+            logger.debug(f"Cache exists but is not finished at {path}.")
+            return None
+    except FileNotFoundError:
+        return None
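+
+
+# Illustrative call-site sketch (names such as `cache_dir` are placeholders): a builder
+# can short-circuit when a finished ledger already exists, e.g.
+#
+#   ledger = _try_load(os.path.join(cache_dir, LEDGER_FILE_NAME))
+#   if ledger is not None:
+#       ...  # cache is complete; serve it instead of rebuilding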
diff --git a/src/levanter/store/tree_store.py b/src/levanter/store/tree_store.py
index 03355a8d2..83d6c88b0 100644
--- a/src/levanter/store/tree_store.py
+++ b/src/levanter/store/tree_store.py
@@ -172,6 +172,9 @@ def get_batch_sync(self, indices) -> List[T]:
 
         return out
 
+    async def async_len(self) -> int:
+        return await jax.tree.leaves(self.tree)[0].num_rows_async()
+
 
 def _construct_builder_tree(exemplar, path, mode, cache_metadata):
     def open_builder(tree_path, item):
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 3973c025a..c7c1a5285 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -498,7 +498,8 @@ def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwa
         grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
         mbs = self.config.microbatch_size
         grad_fn = microbatched(grad_fn, self.TrainBatch, mbs, self.parameter_axis_mapping, self.compute_axis_mapping)
-        return grad_fn(model, *batch, **batch_kwargs)
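+        # compute the (micro)batched gradients under the compute axis mapping rather
+        # than the parameter mapping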
+        with hax.axis_mapping(self.compute_axis_mapping):
+            return grad_fn(model, *batch, **batch_kwargs)
 
 
 def _initialize_global_tracker(config, run_id):
diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py
index 64870443d..c8d3931fe 100644
--- a/src/levanter/utils/fsspec_utils.py
+++ b/src/levanter/utils/fsspec_utils.py
@@ -1,5 +1,6 @@
 import braceexpand
 import fsspec
+from fsspec.asyn import AsyncFileSystem
 
 
 def exists(url, **kwargs) -> bool:
@@ -14,7 +15,7 @@ def mkdirs(path):
     fs.makedirs(path, exist_ok=True)
 
 
-def fsspec_expand_glob(url):
+def expand_glob(url):
     expanded_urls = braceexpand.braceexpand(url)
     for expanded_url in expanded_urls:
         if "*" in expanded_url:
@@ -28,3 +29,21 @@ def fsspec_expand_glob(url):
                 yield from [f"{protocol}://{path}" for path in globbed]
         else:
             yield expanded_url
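+
+# For example (illustrative paths), braces are expanded locally and `*` globs are
+# resolved against the target filesystem:
+#
+#   list(expand_glob("gs://bucket/data.{1..3}-of-3.jsonl.gz"))
+#   # -> ["gs://bucket/data.1-of-3.jsonl.gz",
+#   #     "gs://bucket/data.2-of-3.jsonl.gz",
+#   #     "gs://bucket/data.3-of-3.jsonl.gz"]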
+
+
+def remove(url, *, recursive=False, **kwargs):
+    """Remove a file from a remote filesystem."""
+    # TODO: better to use a STS deletion policy or job for this one.
+    fs, path = fsspec.core.url_to_fs(url, **kwargs)
+
+    fs.rm(path, recursive=recursive)
+
+
+async def async_remove(url, *, recursive=False, **kwargs):
+    """Remove a file from a remote filesystem."""
+    fs, path = fsspec.core.url_to_fs(url, **kwargs)
+
+    if isinstance(fs, AsyncFileSystem):
+        return await fs._rm(path, recursive=recursive)
+    else:
+        fs.rm(path, recursive=recursive)
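+
+
+# Illustrative usage (paths are placeholders):
+#   remove("gs://my-bucket/stale-cache", recursive=True)
+#   await async_remove("gs://my-bucket/stale-cache/ledger.json")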
diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py
index 7a5475738..a0002b1c1 100644
--- a/tests/test_hf_gpt2_serialize.py
+++ b/tests/test_hf_gpt2_serialize.py
@@ -19,7 +19,7 @@
 from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef
 from levanter.models.attention import AttentionMask
 from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel
-from levanter.models.loss import next_token_loss
+from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss
 from levanter.optim import AdamConfig
 from levanter.utils.tree_utils import inference_mode
 from test_utils import arrays_only, skip_if_no_torch
@@ -132,12 +132,10 @@ def torch_loss(model, input_ids) -> torch.Tensor:
         return model(input_ids, labels=input_ids)[0]
 
     torch_out = torch_loss(torch_model, torch.from_numpy(onp.array(input.array)).to(torch.int64).unsqueeze(0))
-    causal_mask = AttentionMask.causal()
 
-    def compute_loss(model, input_ids):
-        pred_y = model(input_ids, key=None, attn_mask=causal_mask)
-
-        return next_token_loss(model.Pos, model.Vocab, pred_y, input_ids).scalar()
+    def compute_loss(model: LmHeadModel, input_ids):
+        example = LmExample.causal(input_ids)
+        return compute_next_token_loss(model, example, key=None).scalar()
 
     jax_compute_grad = equinox.filter_value_and_grad(compute_loss, has_aux=False)
     jax_grad: Gpt2LMHeadModel
diff --git a/tests/test_llama3.py b/tests/test_llama3.py
index 2fae326d1..653ba723c 100644
--- a/tests/test_llama3.py
+++ b/tests/test_llama3.py
@@ -26,9 +26,10 @@ def get_config(vocab_size=1000):
             "eos_token_id": 128001,
             "hidden_act": "silu",
             "hidden_size": 4096,
+            "head_dim": 64,
             "initializer_range": 0.02,
             "intermediate_size": 14336,
-            "max_position_embeddings": 8192,
+            "max_position_embeddings": 131072,
             "model_type": "llama",
             "num_attention_heads": 32,
             "num_hidden_layers": 32,
@@ -55,6 +56,7 @@ def get_config(vocab_size=1000):
     llama3_8b_config.hidden_size = 16
     llama3_8b_config.intermediate_size = 64
     llama3_8b_config.num_attention_heads = 4
+    llama3_8b_config.head_dim = 4
     llama3_8b_config.num_hidden_layers = 4
     llama3_8b_config.num_key_value_heads = 2
     llama3_8b_config.max_position_embeddings = 128
diff --git a/tests/test_lora.py b/tests/test_lora.py
index f7d852531..b6933f935 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -74,8 +74,8 @@ def __call__(self, x):
         @staticmethod
         def init(*, key):
             k1, k2 = jax.random.split(key)
-            first = hnn.Linear.init(In, Mid, key=k1)
-            second = hnn.Linear.init(Mid, In, key=k2)
+            first = hnn.Linear.init(In, Mid, key=k1, out_first=True)
+            second = hnn.Linear.init(Mid, In, key=k2, out_first=True)
             return Module(first, second)
 
     Layers = hax.Axis("Layers", 3)
@@ -91,7 +91,7 @@ def init(*, key):
     assert loraized.stacked.first.lora.lora_A.weight.axes == (Layers, hax.Axis("LORA_R", 8), In)
     assert loraized.stacked.first.lora.lora_B.weight.axes == (Layers, Mid, hax.Axis("LORA_R", 8))
 
-    assert loraized.stacked.second.weight.axes == (Layers, Mid, In)
+    assert loraized.stacked.second.weight.axes == (Layers, In, Mid)
     input = hax.random.normal(k0, (In,))
     assert not hax.all(hax.isclose(module.fold(input), loraized.fold(input)))
 
diff --git a/tests/test_loss.py b/tests/test_loss.py
new file mode 100644
index 000000000..30d140ede
--- /dev/null
+++ b/tests/test_loss.py
@@ -0,0 +1,325 @@
+# Tests for cross-entropy loss (full vs. block-wise) in levanter.models.loss
+import math
+
+import equinox
+import jax.numpy as jnp
+import jax.random
+import pytest
+
+import haliax as hax
+from haliax import NamedArray
+
+from levanter.models.loss import _blockwise_cross_entropy_loss, cross_entropy_loss_and_log_normalizers
+from levanter.utils.jax_utils import key_iterator
+
+
+Batch = hax.Axis("batch", size=2)
+Seq = hax.Axis("seq", size=3)
+Embed = hax.Axis("embed", size=8)
+Vocab = hax.Axis("vocab", size=16)
+
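+# The block-wise loss leans on the identity that a full logsumexp can be accumulated
+# block by block with logaddexp. A minimal sketch of that identity in plain jax.numpy,
+# independent of the levanter implementation under test:
+#
+#   x = jax.random.normal(jax.random.PRNGKey(0), (16,))
+#   full = jax.nn.logsumexp(x)
+#   running = -jnp.inf
+#   for block in jnp.split(x, 4):
+#       running = jnp.logaddexp(running, jax.nn.logsumexp(block))
+#   assert jnp.allclose(full, running, atol=1e-5)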
+
+@pytest.fixture
+def test_data():
+    """
+    Create synthetic test data for cross-entropy loss computation.
+    """
+
+    key = key_iterator(jax.random.PRNGKey(0))
+
+    # Random embeddings, scaled so the resulting logits stay O(1)
+    pred_embeddings = hax.random.normal(next(key), (Batch, Seq, Embed), dtype=jnp.float32) / math.sqrt(Embed.size)
+
+    # Random LM head projection with the same scaling
+    pred_lm_head = hax.random.normal(next(key), (Vocab, Embed), dtype=jnp.float32) / math.sqrt(Embed.size)
+
+    # Random target ids drawn uniformly from the vocab
+    true_ids = hax.random.randint(next(key), (Batch, Seq), 0, Vocab.size)
+
+    return pred_embeddings, pred_lm_head, true_ids
+
+
+def test_basic_equivalence(test_data):
+    """
+    Test that block-wise loss equals full loss when block_size perfectly divides vocab_size.
+    """
+    pred_embeddings, pred_lm_head, true_ids = test_data
+
+    # Compute full loss
+    logits_full = hax.dot(pred_embeddings, pred_lm_head, axis="embed")
+    target_y_full = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype)
+    loss_full, norm_full = cross_entropy_loss_and_log_normalizers(logits_full, Vocab, target_y_full)
+
+    loss_block, norm_this = _blockwise_cross_entropy_loss(
+        (pred_embeddings, pred_lm_head),
+        Contract=Embed,
+        Label=Vocab,
+        labels_y=true_ids,
+        block_size=8,
+        dtype=pred_embeddings.dtype,
+    )
+
+    # Assert that the losses are close
+    assert hax.all(
+        hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
+    ), "Block-wise loss does not match full loss."
+
+
+def test_single_block(test_data):
+    """
+    Test behavior when vocab_size equals block_size.
+    """
+    pred_embeddings, pred_lm_head, true_ids = test_data
+
+    # Compute full loss
+    loss_full, sumexp_full = _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids)
+
+    # Compute block-wise loss with a single block (block_size == Vocab.size)
+    with jax.disable_jit():
+        loss_block, sumexp_block = _blockwise_cross_entropy_loss(
+            (pred_embeddings, pred_lm_head),
+            Contract=Embed,
+            Label=Vocab,
+            labels_y=true_ids,
+            block_size=Vocab.size,
+            dtype=pred_embeddings.dtype,
+        )
+
+    # Assert that the losses are close
+    assert hax.all(
+        hax.isclose(sumexp_full, sumexp_block, atol=1e-3, rtol=1e-3)
+    ), "Single block-wise sumexp does not match full sumexp."
+    assert hax.all(
+        hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
+    ), "Single block-wise loss does not match full loss."
+
+
+def _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids):
+    logits_full = hax.dot(pred_embeddings, pred_lm_head, axis="embed")
+    target_y_full = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype)
+    loss_full, sumexp_full = cross_entropy_loss_and_log_normalizers(logits_full, Vocab, target_y_full)
+    return loss_full, sumexp_full
+
+
+def test_multiple_blocks(test_data):
+    """
+    Test block-wise loss with multiple blocks.
+    """
+    pred_embeddings, pred_lm_head, true_ids = test_data
+
+    # Compute full loss
+    loss_full, logz_full = _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids)
+
+    # Compute block-wise loss with block_size=1 (one block per vocab entry)
+    loss_block, logz_block = _blockwise_cross_entropy_loss(
+        (pred_embeddings, pred_lm_head),
+        Contract=Embed,
+        Label=Vocab,
+        labels_y=true_ids,
+        block_size=1,
+        dtype=pred_embeddings.dtype,
+    )
+
+    # Assert that the losses are close
+    assert hax.all(
+        hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)
+    ), "Multiple block-wise logz does not match full logz."
+    assert hax.all(
+        hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
+    ), "Multiple block-wise loss does not match full loss."
+
+
+def test_block_size_not_dividing_vocab(test_data):
+    pred_embeddings, pred_lm_head, true_ids = test_data
+
+    # Set block_size that does not divide vocab_size
+    block_size = 3  # Vocab.size is 16, which 3 does not divide
+
+    # should be fine now
+    loss_block, logz_block = _blockwise_cross_entropy_loss(
+        (pred_embeddings, pred_lm_head),
+        Contract=Embed,
+        Label=Vocab,
+        labels_y=true_ids,
+        block_size=block_size,
+        dtype=pred_embeddings.dtype,
+    )
+
+    # Compute full loss
+    loss_full, logz_full = cross_entropy_loss_and_log_normalizers(
+        pred_y=hax.dot(pred_embeddings, pred_lm_head, axis="embed"),
+        Label=Vocab,
+        target_y=hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype),
+    )
+
+    # Assert that the losses are close
+    assert hax.all(
+        hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
+    ), "Block-wise loss does not match full loss."
+    assert hax.all(
+        hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)
+    ), "Block-wise logz does not match full logz."
+
+
+def test_vocab_size_less_than_block_size(test_data):
+    """
+    Test behavior when vocab_size is less than block_size.
+    """
+    pred_embeddings, pred_lm_head, true_ids = test_data
+
+    # Set block_size greater than vocab_size
+    block_size = 20  # larger than Vocab.size (16)
+
+    # should be fine now
+    loss_block, logz_block = _blockwise_cross_entropy_loss(
+        (pred_embeddings, pred_lm_head),
+        Contract=Embed,
+        Label=Vocab,
+        labels_y=true_ids,
+        block_size=block_size,
+        dtype=pred_embeddings.dtype,
+    )
+
+    # Compute full loss
+    loss_full, logz_full = cross_entropy_loss_and_log_normalizers(
+        pred_y=hax.dot(pred_embeddings, pred_lm_head, axis="embed"),
+        Label=Vocab,
+        target_y=hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype),
+    )
+
+    # Assert that the losses are close
+    assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)), "loss does not match full loss."
+    assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)), "logz does not match full logz."
+
+
+def test_large_vocab():
+    """
+    Test block-wise loss with separately constructed axes and constant inputs.
+    """
+    Batch = hax.Axis("batch", size=4)
+    Seq = hax.Axis("seq", size=5)
+    Embed = hax.Axis("embed", size=6)
+    Vocab = hax.Axis("vocab", size=12)
+
+    pred_embeddings = NamedArray(
+        jnp.ones((Batch.size, Seq.size, Embed.size)),
+        axes=(Batch, Seq, Embed),
+    )
+    pred_lm_head = NamedArray(
+        jnp.ones((Embed.size, Vocab.size)),
+        axes=(Embed, Vocab),
+    )
+    true_ids = NamedArray(
+        jnp.zeros((Batch.size, Seq.size), dtype=jnp.int32),
+        axes=(Batch, Seq),
+    )
+
+    # Compute full loss
+    loss_full, logz_full = cross_entropy_loss_and_log_normalizers(
+        pred_y=hax.dot(pred_embeddings, pred_lm_head, axis="embed"),
+        Label=Vocab,
+        target_y=hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype),
+    )
+
+    # Compute block-wise loss with block_size=3 (vocab_size=12 is divisible by 3)
+    loss_block, logz_block = _blockwise_cross_entropy_loss(
+        (pred_embeddings, pred_lm_head),
+        Contract=Embed,
+        Label=Vocab,
+        labels_y=true_ids,
+        block_size=3,
+        dtype=pred_embeddings.dtype,
+    )
+
+    # Assert that the losses are close
+    assert hax.all(
+        hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
+    ), "Large vocab block-wise loss does not match full loss."
+    assert hax.all(
+        hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)
+    ), "Large vocab block-wise logz does not match full logz."
+
+
+@pytest.mark.parametrize("block_size", [1, 2, 3, 4, 5])
+def test_gradient_block_cross_entropy(block_size, test_data):
+    """
+    Test the gradient of block-wise cross-entropy loss.
+    """
+    pred_embeddings, pred_lm_head, true_ids = test_data
+
+    # Compute block-wise loss
+    def custom_fn(pred):
+        pred_embeddings, pred_lm_head = pred
+        a, b = _blockwise_cross_entropy_loss(
+            (pred_embeddings, pred_lm_head),
+            Contract=Embed,
+            Label=Vocab,
+            labels_y=true_ids,
+            block_size=block_size,
+            dtype=pred_embeddings.dtype,
+        )
+
+        return (a.mean() + b.mean()).scalar()
+
+    g_embed, g_head = equinox.filter_grad(custom_fn)((pred_embeddings, pred_lm_head))
+
+    # compute directly
+
+    def direct_fn(pred):
+        pred_embeddings, pred_lm_head = pred
+        logits = hax.dot(pred_embeddings, pred_lm_head, axis="embed")
+        target_y = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype)
+        loss, logz = cross_entropy_loss_and_log_normalizers(logits, Vocab, target_y)
+        return (loss.mean() + logz.mean()).scalar()
+
+    g_embed_direct, g_head_direct = equinox.filter_grad(direct_fn)((pred_embeddings, pred_lm_head))
+
+    assert hax.all(
+        hax.isclose(g_embed, g_embed_direct, atol=1e-3, rtol=1e-3)
+    ), "Gradient of embeddings does not match."
+    assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-3, rtol=1e-3)), "Gradient of lm_head does not match."
+
+
+def test_grad_loss_without_logz(test_data):
+    """
+    Test the gradient of block-wise cross-entropy loss without logz.
+    """
+    pred_embeddings, pred_lm_head, true_ids = test_data
+
+    # Compute block-wise loss
+    def custom_fn(pred):
+        pred_embeddings, pred_lm_head = pred
+        a, b = _blockwise_cross_entropy_loss(
+            (pred_embeddings, pred_lm_head),
+            Contract=Embed,
+            Label=Vocab,
+            labels_y=true_ids,
+            block_size=2,
+            dtype=pred_embeddings.dtype,
+        )
+
+        return a.mean().scalar()
+
+    g_embed, g_head = equinox.filter_grad(custom_fn)((pred_embeddings, pred_lm_head))
+
+    # compute directly
+
+    def direct_fn(pred):
+        pred_embeddings, pred_lm_head = pred
+        logits = hax.dot(pred_embeddings, pred_lm_head, axis="embed", preferred_element_type=pred_embeddings.dtype)
+        target_y = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype)
+        loss, _ = cross_entropy_loss_and_log_normalizers(logits, Vocab, target_y)
+        return loss.mean().scalar()
+
+    g_embed_direct, g_head_direct = equinox.filter_grad(direct_fn)((pred_embeddings, pred_lm_head))
+
+    assert hax.all(
+        hax.isclose(g_embed, g_embed_direct, atol=1e-3, rtol=1e-3)
+    ), "Gradient of embeddings does not match."
+    assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-3, rtol=1e-3)), "Gradient of lm_head does not match."
diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py
index c1eb73670..086de48e1 100644
--- a/tests/test_new_cache.py
+++ b/tests/test_new_cache.py
@@ -1,6 +1,4 @@
 import asyncio
-import copy
-import os
 import tempfile
 from typing import Any, Dict, Iterator, Sequence
 
@@ -10,17 +8,7 @@
 
 from levanter.data import BatchProcessor, ShardedDataSource, batched
 from levanter.data.sharded_datasource import TextUrlDataSource
-from levanter.store.cache import (
-    LEDGER_FILE_NAME,
-    CacheLedger,
-    CacheOptions,
-    SerialCacheWriter,
-    ShardedCacheWriter,
-    TreeStore,
-    _get_builder_actor,
-    _serialize_json_and_commit,
-    build_or_load_cache,
-)
+from levanter.store.cache import CacheOptions, SerialCacheWriter, TreeStore, _get_builder_actor, build_or_load_cache
 from levanter.utils.py_utils import logical_cpu_core_count
 
 
@@ -140,13 +128,13 @@ def test_full_end_to_end_cache():
     with td as tmpdir:
         ray_ds = build_or_load_cache(
             tmpdir,
-            SimpleShardSource(num_shards=2),
+            SimpleShardSource(num_shards=15),
             TestProcessor(),
             await_finished=True,
-            options=CacheOptions.no_fanciness(8),
+            options=CacheOptions(num_shard_groups=3, batch_size=8),
         )
 
-        expected = process_interleave(TestProcessor(), SimpleShardSource(num_shards=2), 8)
+        expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=15))
 
         all_data = ray_ds[:]
 
@@ -162,15 +150,14 @@ def test_full_end_to_end_cache_with_groups():
             SimpleShardSource(num_shards=5),
             TestProcessor(),
             await_finished=True,
-            options=CacheOptions(num_shard_groups=2, batch_size=8, shard_order_randomization_key=None),
+            options=CacheOptions(num_shard_groups=2, batch_size=8),
         )
 
-        expected = process_interleave(TestProcessor(), SimpleShardSource(num_shards=5), 8)
+        expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=5))
 
         all_data = ray_ds[:]
 
-        # check_datasets_equal(all_data, expected)
-        assert len(all_data) == len(list(expected))
+        check_datasets_equal(all_data, expected)
 
 
 @pytest.mark.ray
@@ -204,7 +191,6 @@ class _CustomException(Exception):
 
 
 @pytest.mark.ray
-@pytest.mark.skip("This test segfaults in CI. I think a ray bug")
 def test_cache_recover_from_crash():
     class CrashingShardSource(ShardedDataSource[list[int]]):
         def __init__(self, crash_point: int):
@@ -218,7 +204,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]:
             # parse the shard name to get the shard number
             shard_num = int(shard_name.split("_")[1])
             for i in range(10):
-                if shard_num * 10 + i == self.crash_point:
+                if i == self.crash_point:
                     raise _CustomException(f"Crashing at {shard_num} {i} {self.crash_point}")
                 if i >= row:
                     yield [shard_num * 10 + i] * 10
@@ -226,7 +212,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]:
     with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as tmpdir2:
         source = CrashingShardSource(4)
         with pytest.raises(_CustomException):
-            build_or_load_cache(tmpdir, source, TestProcessor())
+            build_or_load_cache(tmpdir, source, TestProcessor(), CacheOptions(target_size_per_flush=1))
 
         # kill the broker actor so that we can test recovery
         ray.kill(
@@ -244,11 +230,11 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]:
         )
 
         # testing this doesn't throw
-        source = CrashingShardSource(1000)
+        source = CrashingShardSource(100000)
         reader1 = build_or_load_cache(tmpdir, source, TestProcessor(), await_finished=True)
 
         # compare to the original with no crash
-        reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(), TestProcessor(), await_finished=True)
+        reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(num_shards=4), TestProcessor(), await_finished=True)
 
         check_datasets_equal(reader1, reader2)
 
@@ -295,7 +281,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]:
         # now block until the cache is done
         cache.await_finished(timeout=30)
 
-        expected = process_interleave(processor, SlowShardSource(), 16)
+        expected = simple_process(processor, SlowShardSource())
 
         check_datasets_equal(list(cache[:]), expected)
 
@@ -334,13 +320,12 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]:
             SlowShardSource(),
             TestProcessor(),
             await_finished=False,
-            force_flush=True,
-            options=CacheOptions.no_fanciness(5),
-        )  # we need force_flush to ensure the cache is written to disk
+            options=CacheOptions(target_size_per_flush=1, batch_size=1),
+        )
 
         # read the first 10 elements
         # ensure the first 10 elements are [{"test": np.array([i] * 10)} for i in range(10)]
-        first_10 = list(await cache.get_batch(range(0, 10)))
+        first_10 = list(await asyncio.wait_for(cache.get_batch(range(0, 10)), timeout=30.0))
 
         for i, x in enumerate(first_10):
             np.testing.assert_array_equal(x["test"], np.array([i] * 10))
@@ -353,7 +338,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]:
 
         # now ensure we can get the next 10 elements, which will be
         # [{"test": np.array([i] * 10)} for i in range(10, 20)]
-        batch = await asyncio.wait_for(cache.get_batch(range(10, 20)), timeout=10)
+        batch = await asyncio.wait_for(cache.get_batch(range(10, 20)), timeout=10.0)
 
         for i, x in enumerate(batch):
             np.testing.assert_array_equal(x["test"], np.array([i + 10] * 10))
@@ -364,7 +349,6 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]:
         cache.await_finished(timeout=10)
 
 
-@pytest.mark.skip("This test segfaults in CI. I think a ray bug")
 @pytest.mark.ray
 def test_shard_cache_crashes_if_processor_throws():
     class ThrowingProcessor(SimpleProcessor):
@@ -398,7 +382,6 @@ def test_shard_cache_fails_with_multiple_shards_with_the_same_name():
             build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=True)
 
 
-@pytest.mark.skip("This test segfaults in CI. I think a ray bug")
 @pytest.mark.ray
 @pytest.mark.asyncio
 async def test_shard_cache_fails_gracefully_with_unknown_file_type_async():
@@ -451,89 +434,3 @@ def test_shard_cache_fails_gracefully_with_unknown_file_type():
             cache.await_finished(timeout=10)
 
         del cache
-
-
-def test_sharded_cache_writer():
-    with tempfile.TemporaryDirectory() as tmpdir:
-        source = SimpleShardSource(num_shards=4)
-        processor = SimpleProcessor()
-        ledger = CacheLedger.load_or_initialize(tmpdir, source, processor, CacheOptions.no_fanciness(8))
-
-        exemplar = {"data": np.array([0], dtype=np.int64)}
-
-        writer = ShardedCacheWriter(tmpdir, ledger, exemplar)
-        for shard_name in source.shard_names:
-            for ex in batched(source.open_shard(shard_name), ledger.metadata.options.batch_size):
-                writer.write_batch(shard_name, processor(ex))
-
-        for shard_name in source.shard_names:
-            writer.finish_shard(shard_name, source._rows_per_shard)
-
-        store = writer.finish()
-
-        data_path = store.path
-
-        del store
-
-        builder = TreeStore.open(exemplar, data_path, mode="r")
-
-        assert len(builder) == 40
-
-        for i, x in enumerate(builder):
-            np.testing.assert_array_equal(x["data"], np.asarray([i % 10 + i // 10 * 10] * 10))
-
-        # check totals for the ledger
-        ledger = writer.ledger
-        assert ledger.total_num_rows == 40
-        assert ledger.is_finished
-
-        for shard_name in source.shard_names:
-            assert ledger.shard_rows[shard_name] == 10
-
-
-def test_sharded_cache_writer_trims_on_resume():
-    with tempfile.TemporaryDirectory() as tmpdir:
-        source = SimpleShardSource(num_shards=4)
-        processor = SimpleProcessor()
-
-        exemplar = {"data": np.array([0], dtype=np.int64)}
-
-        ledger = CacheLedger.load_or_initialize(tmpdir, source, processor, CacheOptions.no_fanciness(batch_size=8))
-
-        writer = ShardedCacheWriter(tmpdir, ledger, exemplar)
-        for shard_name in source.shard_names:
-            for ex in batched(source.open_shard(shard_name), 8):
-                writer.write_batch(shard_name, processor(ex))
-
-        for shard_name in source.shard_names:
-            writer.finish_shard(shard_name, 10)
-
-        writer.finish()
-
-        # now deliberately truncate the ledger a bit
-        ledger = copy.deepcopy(writer.ledger)
-        assert ledger.total_num_rows == 40
-        assert ledger.is_finished
-        ledger.total_num_rows = 24
-        ledger.shard_rows["shard_0"] = 8
-        ledger.shard_rows["shard_1"] = 8
-        ledger.shard_rows["shard_2"] = 8
-        ledger.shard_rows["shard_3"] = 0
-        ledger.is_finished = False
-
-        _serialize_json_and_commit(os.path.join(tmpdir, LEDGER_FILE_NAME), ledger)
-
-        writer = ShardedCacheWriter(tmpdir, ledger, exemplar)
-
-        # ensure it got truncated
-        assert writer.ledger.total_num_rows == 24
-        assert writer.ledger.is_finished is False
-        assert writer.ledger.shard_rows["shard_0"] == 8
-        assert writer.ledger.shard_rows["shard_1"] == 8
-        assert writer.ledger.shard_rows["shard_2"] == 8
-        assert writer.ledger.shard_rows["shard_3"] == 0
-
-        new_store = writer._tree_store
-        new_data = new_store[:]
-
-        assert len(new_data) == 24
diff --git a/tests/test_text.py b/tests/test_text.py
index a2645c1f9..e4e51acbc 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -26,6 +26,7 @@ def test_dont_blow_up_without_validation_set():
 def test_lm_example_handles_ignore_id():
     Pos = hax.Axis("Pos", 10)
     Vocab = hax.Axis("vocab", Pos.size + 1)
+    Embed = hax.Axis("embed", 10)
     tokens = hax.arange(Pos, dtype=jnp.int32)
 
     ignore_id = 6
@@ -34,11 +35,12 @@ def test_lm_example_handles_ignore_id():
     ex_no_ignore = LmExample.causal(tokens)
     assert ex_ignore.loss_mask[Pos, ignore_id - 1] == 0
 
-    distr = -100 * hax.nn.one_hot(ignore_id, Vocab)
-    distr = distr.broadcast_axis(Pos)
+    activations = hax.ones((Pos, Embed))
+    lm_head = hax.zeros((Embed, Vocab))
+    lm_head = lm_head.at[Vocab, ignore_id].set(-100)
 
-    ignored_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_ignore.loss_mask)
-    no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask)
+    ignored_loss = next_token_loss(Pos, Embed, Vocab, activations, lm_head, tokens, loss_mask=ex_ignore.loss_mask)
+    no_ignore_loss = next_token_loss(Pos, Embed, Vocab, activations, lm_head, tokens, loss_mask=ex_no_ignore.loss_mask)
 
     assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size