diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 6d3a8a154..49da98456 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -18,8 +18,7 @@ from tqdm_loggable.auto import tqdm import levanter.tracker -from levanter.data import DataLoader -from levanter.data.text import TokenSeqEpochDataset +from levanter.data import DataLoader, AsyncDataset from levanter.logging import save_xla_dumps_to_wandb from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig @@ -55,7 +54,7 @@ def log_epoch(step_info: StepInfo): return log_epoch -def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int): +def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int): if not ds.is_finite(): raise ValueError("Epochs don't make sense with an infinite dataset.")