From 1706803e87a818e1125994b3f6c84e2c9a4f03ee Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 22 Oct 2024 19:58:17 -0700 Subject: [PATCH] add suggested fix from david --- src/levanter/callbacks.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 6d3a8a154..49da98456 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -18,8 +18,7 @@ from tqdm_loggable.auto import tqdm import levanter.tracker -from levanter.data import DataLoader -from levanter.data.text import TokenSeqEpochDataset +from levanter.data import DataLoader, AsyncDataset from levanter.logging import save_xla_dumps_to_wandb from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig @@ -55,7 +54,7 @@ def log_epoch(step_info: StepInfo): return log_epoch -def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int): +def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int): if not ds.is_finite(): raise ValueError("Epochs don't make sense with an infinite dataset.")