diff --git a/examples/sft/sft.py b/examples/sft/sft.py
deleted file mode 100644
index 152781b0b..000000000
--- a/examples/sft/sft.py
+++ /dev/null
@@ -1,205 +0,0 @@
-import logging
-import os
-from dataclasses import dataclass
-from enum import Enum
-from typing import List, Optional
-
-import jax.random as jrandom
-import transformers
-
-import haliax as hax
-from haliax import Axis
-from haliax.partitioning import round_axis_for_partitioning
-
-import levanter
-from levanter import callbacks
-from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
-from levanter.data import PermutationDataset
-from levanter.data.text import ChatSFTDatasetConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset
-from levanter.main.train_lm import TrainLmConfig
-from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
-from levanter.trainer import Trainer
-
-
-logger = logging.getLogger(__name__)
-
-# Define default special tokens
-DEFAULT_PAD_TOKEN = "[PAD]"
-DEFAULT_EOS_TOKEN = "</s>"
-DEFAULT_BOS_TOKEN = "<s>"
-DEFAULT_UNK_TOKEN = "<unk>"
-
-
-class DatasetType(str, Enum):
-    """Type of dataset to use"""
-
-    HUGGINGFACE = "huggingface"  # Use HF dataset
-    CHAT_JSONL = "chat_jsonl"  # Use JSONL files with chat format
-
-
-@dataclass
-class SFTConfig(TrainLmConfig):
-    # inherit most of the config from TrainLmConfig
-    max_tune_length: int = 2048
-    model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
-    tokenizer: str = "meta-llama/Llama-2-7b-hf"
-
-    # Add dataset type and chat-specific fields
-    dataset_type: DatasetType = DatasetType.HUGGINGFACE
-    chat_train_urls: Optional[List[str]] = None
-    messages_field: str = "messages"
-    input_role: str = "user"
-    output_role: str = "assistant"
-
-
-def train(config: SFTConfig):
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        config.tokenizer,
-        model_max_length=config.max_tune_length,
-        padding_side="right",
-        trust_remote_code=True,
-    )
-    logger.info(f"Loaded tokenizer {tokenizer}")
-
-    if config.initialize_from_hf:
-        if config.trainer.initialize_from is not None:
-            raise ValueError("Cannot use both --initialize_from_hf and --initialize_from")
-
-        assert isinstance(config.model, HFCompatConfig)
-
-        converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=True)
-        if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab:
-            logger.warning("The tokenizers appear to be different. You may want to check this.")
-        if isinstance(config.initialize_from_hf, str):
-            converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer)
-        else:
-            converter = converter.replaced(tokenizer=tokenizer)
-
-        model_config = converter.default_config
-    elif config.trainer.initialize_from is None:
-        raise ValueError("Must specify either --initialize_from_hf or --initialize_from")
-    else:
-        converter = None
-        model_config = config.model
-
-    levanter.initialize(config)
-
-    num_new_tokens = add_special_tokens(tokenizer)
-    logger.info(f"Added {num_new_tokens} new tokens")
-    # Randomness in JAX is tightly controlled by "keys", which are the states of the random number generators.
-    # This makes deterministic training pretty easy.
-    seed = config.trainer.seed
-    data_key, _, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4)
-
-    if config.data_seed is not None:
-        logger.info(f"Overriding data seed with {config.data_seed}")
-        data_key = jrandom.PRNGKey(config.data_seed)
-
-    # Create supervised dataset using generic machinery
-    logger.info("Creating supervised dataset")
-    if config.dataset_type == DatasetType.CHAT_JSONL:
-        assert config.chat_train_urls is not None
-        assert config.supervised_data is not None
-        chat_config = ChatSFTDatasetConfig(
-            cache_dir=config.supervised_data.cache_dir,
-            train_urls=config.chat_train_urls,  # No validation in this config
-            messages_field=config.messages_field,
-            input_role=config.input_role,
-            output_role=config.output_role,
-        )
-        train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos)
-    else:
-        assert config.supervised_data is not None
-        train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer, model_config.Pos)
-    logger.info("Supervised dataset created")
-    train_dataset = PermutationDataset(train_dataset, data_key)
-
-    # Then wrap for epochs
-    if config.epoch > 0:
-        logger.info(f"Wrapping dataset for {config.epoch} epochs")
-        train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch)
-
-    logger.info("Creating optimizer")
-    optimizer = config.optimizer.build(config.trainer.num_train_steps)
-
-    # Using the trainer as a context manager does 3 things:
-    # 1. Sets the device mesh
-    # 2. Sets the axis mapping (for fsdp)
-    # 3. Sets the global metrics tracker
-    with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer:  # type: ignore
-        parameter_axis_mapping = trainer.parameter_axis_mapping
-
-        # We have two axis_mappings: one for storing the model and optimizer states, and one for compute
-        # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh
-        parameter_axis_mapping = trainer.parameter_axis_mapping
-
-        # some axes we need
-        Pos = config.model.Pos
-
-        # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
-        # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
-        # tokens: gpt-2 has 50257, for example. So we round up.
-        vocab_size = len(tokenizer)
-        Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
-        if config.initialize_from_hf:
-            logger.info(f"Loading pretrained model from {converter.reference_checkpoint}")
-            model: LmHeadModel = converter.load_pretrained(
-                model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype
-            )  # type: ignore
-            model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model)
-            state = trainer.initial_state(training_key, model=model)
-        else:
-            if vocab_size != Vocab.size:
-                logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")
-            state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
-
-        flops_per_token = config.model.flops_per_token(vocab_size)
-        flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
-        trainer.add_hook(
-            callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
-        )
-
-        loader = trainer.data_loader(train_dataset, trainer.TrainBatch)
-
-        if int(state.step) != 0:
-            logger.info(f"Resuming training from step {state.step}")
-            for i in range(state.step):
-                next(loader)
-
-        if config.hf_save_path is not None:
-            # bit gross to reach this far into the config, but it's fine
-            if config.trainer.checkpointer.append_run_id_to_base_path:
-                full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
-            else:
-                full_save_path = config.hf_save_path
-
-            trainer.add_hook(
-                save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
-                every=config.hf_save_steps,
-            )
-
-        trainer.train(state, loader)
-
-
-def add_special_tokens(tokenizer, use_unk_instead_of_adding=False):
-    special_tokens_dict = dict()
-    if use_unk_instead_of_adding:
-        if tokenizer.unk_token is None:
-            raise ValueError("use_unk_instead_of_add is True but tokenizer doesn't have an unk token")
-
-    unk = tokenizer.unk_token if use_unk_instead_of_adding else None
-
-    if tokenizer.pad_token is None:
-        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN if not use_unk_instead_of_adding else unk
-    if tokenizer.eos_token is None:
-        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN if not use_unk_instead_of_adding else unk
-    if tokenizer.bos_token is None:
-        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN if not use_unk_instead_of_adding else unk
-    if tokenizer.unk_token is None:
-        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
-
-    return tokenizer.add_special_tokens(special_tokens_dict)
-
-
-if __name__ == "__main__":
-    levanter.config.main(train)()
diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py
index 629b556c2..51776124e 100644
--- a/src/levanter/main/sft.py
+++ b/src/levanter/main/sft.py
@@ -160,11 +160,6 @@ def train(config: SFTConfig):
 
         loader = trainer.data_loader(train_dataset, trainer.TrainBatch)
 
-        if int(state.step) != 0:
-            logger.info(f"Resuming training from step {state.step}")
-            for i in range(state.step):
-                next(loader)
-
         if config.hf_save_path is not None:
             # bit gross to reach this far into the config, but it's fine
             if config.trainer.checkpointer.append_run_id_to_base_path: