diff --git a/config/llama_7b_tulu.yaml b/config/llama_7b_tulu.yaml
index 48af18b2a..cf333f850 100644
--- a/config/llama_7b_tulu.yaml
+++ b/config/llama_7b_tulu.yaml
@@ -36,4 +36,4 @@ optimizer:
   min_lr_ratio: 0.1
   warmup: 5000
 
-epoch: False
+epoch: 0
diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py
index 49da98456..2eae0185e 100644
--- a/src/levanter/callbacks.py
+++ b/src/levanter/callbacks.py
@@ -55,8 +55,6 @@ def log_epoch(step_info: StepInfo):
 
 
 def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int):
-    if not ds.is_finite():
-        raise ValueError("Epochs don't make sense with an infinite dataset.")
 
     def log_length():
         # If ds.async_len() is the only option, run it in an event loop inside the thread
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 44931414c..3b9789cda 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -73,8 +73,6 @@ class EpochDataset(AsyncDataset[T_co]):
     :param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely.
     """
 
     def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None):
-        if dataset.is_finite():
-            raise ValueError("Cannot apply epoching to a finite dataset.")
         self.dataset = dataset
         self.max_epochs = max_epochs
@@ -737,7 +735,7 @@ def train_set(
         ds = self.token_seq_dataset("train", seq_len, monitors)
         if epochs:
             logger.info("Wrapping dataset in epoch dataset")
-            ds = EpochDataset(ds)
+            ds = EpochDataset(ds, max_epochs=epochs)
 
         # add epoch flag here.
         if ds is None:
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index 6f76482f2..54a39700e 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -54,7 +54,7 @@ class TrainLmConfig:
     data_seed: Optional[int] = None  # if provided, will override the data seed from the trainer
     initialize_from_checkpoint_path: Optional[str] = None
     # if provided, will initialize from this checkpoint, used for llama style data mixture
-    epoch: bool | int = False
+    epoch: int = 0
 
 
 def main(config: TrainLmConfig):
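
Note on the change, with an illustrative sketch: `epoch` becomes a plain integer rather than `bool | int`; `0` (the new default) leaves the training set untouched, while a positive value is forwarded by `train_set` as `max_epochs` to `EpochDataset`, and the `is_finite()` checks in `callbacks.py` and `EpochDataset.__init__` are removed. The code below is a minimal, synchronous stand-in for that index-recycling idea, not Levanter's actual `EpochDataset`; `ToyEpochDataset` and `wrap_for_epochs` are hypothetical names used only for illustration.

from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


class ToyEpochDataset:
    """Cycles over a finite dataset for up to `max_epochs` passes (None = cycle forever).

    Simplified, synchronous stand-in for the async EpochDataset touched in this diff.
    """

    def __init__(self, dataset: Sequence[T], max_epochs: Optional[int] = None):
        self.dataset = dataset
        self.max_epochs = max_epochs

    def __len__(self) -> int:
        if self.max_epochs is None:
            raise ValueError("Unbounded epoching has no finite length")
        return len(self.dataset) * self.max_epochs

    def __getitem__(self, index: int) -> T:
        epoch = index // len(self.dataset)
        if self.max_epochs is not None and epoch >= self.max_epochs:
            raise IndexError(f"index {index} is past epoch {self.max_epochs}")
        # Recycle indices: a global index wraps back into the underlying dataset.
        return self.dataset[index % len(self.dataset)]


def wrap_for_epochs(ds: Sequence[T], epochs: int = 0):
    # Mirrors the config semantics: epochs == 0 means "no epoching", N > 0 means N passes.
    return ToyEpochDataset(ds, max_epochs=epochs) if epochs else ds

In this toy version, `epoch: 3` in the YAML would correspond to a wrapper exposing 3 * len(dataset) samples before indexing past the end, which is the behavior the added `max_epochs=epochs` argument in `train_set` is meant to enable.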