remove changes that break epochs
ahmeda14960 committed Oct 23, 2024
1 parent f0ca163 commit c971ebf
Showing 4 changed files with 3 additions and 7 deletions.
config/llama_7b_tulu.yaml (2 changes: 1 addition & 1 deletion)

@@ -36,4 +36,4 @@ optimizer:
   min_lr_ratio: 0.1
   warmup: 5000
 
-epoch: False
+epoch: 0
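
The flag is now an integer count rather than a boolean. A minimal sketch of the updated config semantics, assuming the behavior visible in `train_set` below (0 leaves the dataset unwrapped; a positive value is forwarded as `max_epochs`):

```yaml
optimizer:
  min_lr_ratio: 0.1
  warmup: 5000

# 0 disables epoch wrapping; a positive value caps the number of passes
epoch: 2
```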
src/levanter/callbacks.py (2 changes: 0 additions & 2 deletions)

@@ -55,8 +55,6 @@ def log_epoch(step_info: StepInfo):
 
 
 def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int):
-    if not ds.is_finite():
-        raise ValueError("Epochs don't make sense with an infinite dataset.")
 
     def log_length():
         # If ds.async_len() is the only option, run it in an event loop inside the thread
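
With the finiteness guard removed, token counting just asks the dataset for its length. A minimal sketch of the surrounding helper, assuming the `async_len()` coroutine mentioned in the remaining comment (the real logging and threading details in callbacks.py may differ):

```python
import asyncio
import logging
import threading

logger = logging.getLogger(__name__)


def get_total_dataset_tokens(ds, seq_length: int):
    def log_length():
        # ds.async_len() is a coroutine, so run it in its own event loop
        # inside a worker thread instead of blocking the training loop.
        num_examples = asyncio.run(ds.async_len())
        logger.info(f"Total dataset tokens: {num_examples * seq_length}")

    threading.Thread(target=log_length, daemon=True).start()
```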
src/levanter/data/text.py (4 changes: 1 addition & 3 deletions)

@@ -73,8 +73,6 @@ class EpochDataset(AsyncDataset[T_co]):
     :param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely.
     """
     def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None):
-        if dataset.is_finite():
-            raise ValueError("Cannot apply epoching to a finite dataset.")
         self.dataset = dataset
         self.max_epochs = max_epochs
 
@@ -737,7 +735,7 @@ def train_set(
         ds = self.token_seq_dataset("train", seq_len, monitors)
         if epochs:
             logger.info("Wrapping dataset in epoch dataset")
-            ds = EpochDataset(ds)
+            ds = EpochDataset(ds, max_epochs=epochs)
 
         # add epoch flag here.
         if ds is None:
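
The constructor now accepts any `AsyncDataset`, and `train_set` forwards the configured count as `max_epochs`. A minimal sketch of the cycling behavior the docstring describes, assuming an `async_len`/`async_getitem` interface (a hypothetical simplification; the real class carries more machinery):

```python
from typing import Generic, Optional, TypeVar

T_co = TypeVar("T_co", covariant=True)


class EpochDataset(Generic[T_co]):
    """Cycles through `dataset` repeatedly, up to `max_epochs` passes (None = forever)."""

    def __init__(self, dataset, max_epochs: Optional[int] = None):
        self.dataset = dataset
        self.max_epochs = max_epochs

    async def async_getitem(self, index: int) -> T_co:
        n = await self.dataset.async_len()
        if self.max_epochs is not None and index >= n * self.max_epochs:
            raise IndexError(f"index {index} is beyond {self.max_epochs} epochs")
        # Map the ever-growing global index back into the underlying dataset,
        # so indices [0, n) are epoch 0, [n, 2n) are epoch 1, and so on.
        return await self.dataset.async_getitem(index % n)
```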
src/levanter/main/train_lm.py (2 changes: 1 addition & 1 deletion)

@@ -54,7 +54,7 @@ class TrainLmConfig:
     data_seed: Optional[int] = None  # if provided, will override the data seed from the trainer
     initialize_from_checkpoint_path: Optional[str] = None
     # if provided, will initialize from this checkpoint, used for llama style data mixture
-    epoch: bool | int = False
+    epoch: int = 0
 
 
 def main(config: TrainLmConfig):
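
Downstream, the integer flows straight through, and the `if epochs:` check in `train_set` makes 0 behave like the old `False`. A hypothetical call site (argument names assumed; the actual wiring inside `main` is not shown in this diff):

```python
# config.epoch == 0: plain streaming dataset, exactly as before
# config.epoch == n > 0: dataset wrapped as EpochDataset(ds, max_epochs=n)
train_dataset = config.data.train_set(seq_len, monitors, epochs=config.epoch)
```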
