diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 3b9789cda..ff1ce154f 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -729,7 +729,7 @@ def train_set(
         monitors: Union[bool, List[MetricsMonitor]] = True,
         *,
         key: Optional[PRNGKeyArray] = None,
-        epochs: bool = False,
+        epochs: int = 0,
     ) -> AsyncDataset[np.ndarray]:
         ds = self.token_seq_dataset("train", seq_len, monitors)
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index 54a39700e..87e6cdc13 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -128,11 +128,12 @@ def main(config: TrainLmConfig):
         )

-        # add epoch logging
-        total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len)
-        trainer.add_hook(
-            callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1
-        )
+        # add epoch logging if epochs specified
+        if config.epoch > 0:
+            total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len)
+            trainer.add_hook(
+                callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1
+            )

         # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
         # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
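
For context, a minimal sketch (not part of the patch) of the pattern the diff adopts: replacing a boolean flag with an integer epoch count, where 0 means "disabled" and any positive value both enables the feature and carries the count, mirroring the `if config.epoch > 0:` guard above. The names here (EpochConfig, maybe_log_epochs) are illustrative only, not Levanter APIs.

# sketch.py -- illustrative only; names are hypothetical, not Levanter code
from dataclasses import dataclass


@dataclass
class EpochConfig:
    epoch: int = 0  # 0 disables epoch logging; N > 0 requests N epochs


def maybe_log_epochs(config: EpochConfig) -> None:
    # Mirrors the guard added in train_lm.py: only set up epoch logging
    # when an epoch count was actually requested.
    if config.epoch > 0:
        print(f"epoch logging enabled for {config.epoch} epoch(s)")
    else:
        print("epoch logging disabled")


if __name__ == "__main__":
    maybe_log_epochs(EpochConfig())         # default: disabled
    maybe_log_epochs(EpochConfig(epoch=3))  # enabled for 3 epochs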