Skip to content

Commit

Permalink
correct num_steps count
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidhyani authored and jettjaniak committed Apr 9, 2024
1 parent fa1f52e commit cc7a6c8
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/delphi/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,10 @@ def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext
limit=config.data_config.validation_sample_limit,
)

# derive iteration params (num_batches, num_steps, etc)
num_batches = len(train_ds) // config.batch_size
num_steps = num_batches // config.gradient_accumulation_steps
# derive iteration params
steps_per_epoch = len(train_ds) // config.batch_size
lr_decay_iters = (
config.max_epochs * num_batches
config.max_epochs * steps_per_epoch
) # should be ~=max_iters per Chinchilla

# model init
Expand All @@ -97,7 +96,7 @@ def run_training(config: TrainingConfig) -> tuple[ModelTrainingState, RunContext
ordering_seed=config.batch_ordering_seed,
)
model_training_state.epoch = epoch
for step in tqdm(range(num_steps)):
for step in tqdm(range(steps_per_epoch)):
model_training_state.step = step
if should_save_checkpoint(config, model_training_state):
log_and_save_checkpoint(
Expand Down

0 comments on commit cc7a6c8

Please sign in to comment.