Skip to content

Commit

Permalink
add suggested fix from david
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Oct 23, 2024
1 parent fd18cae commit 1706803
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/levanter/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from tqdm_loggable.auto import tqdm

import levanter.tracker
from levanter.data import DataLoader
from levanter.data.text import TokenSeqEpochDataset
from levanter.data import DataLoader, AsyncDataset
from levanter.logging import save_xla_dumps_to_wandb
from levanter.tracker.helpers import log_optimizer_hyperparams
from levanter.tracker.wandb import WandbConfig
Expand Down Expand Up @@ -55,7 +54,7 @@ def log_epoch(step_info: StepInfo):
return log_epoch


def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int):
def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int):
if not ds.is_finite():
raise ValueError("Epochs don't make sense with an infinite dataset.")

Expand Down

0 comments on commit 1706803

Please sign in to comment.