Skip to content

Commit

Permalink
refactor: refactored the trainer to work with the new TrainingProgress object
Browse files Browse the repository at this point in the history
  • Loading branch information
le1nux committed Sep 11, 2024
1 parent b4ce789 commit bb91e9e
Showing 1 changed file with 41 additions and 24 deletions.
65 changes: 41 additions & 24 deletions src/modalities/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@

from modalities.batch import DatasetBatch, EvaluationResultBatch
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.logging_broker.messages import BatchProgressUpdate, ExperimentStatus, MessageTypes
from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
from modalities.logging_broker.publisher import MessagePublisher
from modalities.loss_functions import Loss
from modalities.models.model import model_predict_batch
from modalities.running_env.fsdp.reducer import Reducer
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
from modalities.training.training_progress import TrainingProgress
from modalities.util import Aggregator, TimeRecorder, print_rank_0
from modalities.utils.mfu import compute_mfu, get_theoretical_flops_per_token, get_theoretical_gpu_peak_performance

Expand All @@ -29,33 +30,41 @@ class Trainer:
def __init__(
self,
global_rank: int,
batch_progress_publisher: MessagePublisher[BatchProgressUpdate],
progress_publisher: MessagePublisher[ProgressUpdate],
evaluation_result_publisher: MessagePublisher[EvaluationResultBatch],
gradient_acc_steps: int,
global_num_tokens_per_train_step: int,
num_seen_train_steps: int,
global_num_seen_tokens: int,
num_target_steps: int,
num_target_tokens: int,
gradient_clipper: GradientClipperIF,
) -> None:
"""
Initializes the Trainer object.
Args:
global_rank (int): The global rank to which operates the trainer object.
batch_progress_publisher (MessagePublisher[BatchProgressUpdate]): The publisher for batch progress
updates.
progress_publisher (MessagePublisher[ProgressUpdate]): The publisher for progress updates.
evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]):
The publisher for evaluation result batches.
gradient_acc_steps (int): The number of gradient accumulation steps.
global_num_tokens_per_train_step (int): The number of global tokens per training step.
num_seen_train_steps (int): The number of training steps already seen in previous runs.
global_num_seen_tokens (int): The number of tokens already seen in previous runs.
num_target_steps (int): The target number of training steps.
num_target_tokens (int): The target number of tokens to be seen during training.
gradient_clipper (GradientClipperIF): The gradient clipper.
Returns:
None
"""
self.global_rank = global_rank
self.batch_progress_publisher = batch_progress_publisher
self.progress_publisher = progress_publisher
self.evaluation_result_publisher = evaluation_result_publisher
self.gradient_acc_steps = gradient_acc_steps
self.global_num_tokens_per_train_step = global_num_tokens_per_train_step
self.num_seen_train_steps = num_seen_train_steps
self.num_target_steps = num_target_steps
self.num_target_tokens = num_target_tokens
self.global_num_seen_tokens = global_num_seen_tokens
self.gradient_clipper = gradient_clipper

@staticmethod
Expand Down Expand Up @@ -128,8 +137,8 @@ def train(
scheduler: LRScheduler,
loss_fun: Loss,
training_log_interval_in_steps: int,
evaluation_callback: Callable[[int], None],
checkpointing_callback: Callable[[int], None],
evaluation_callback: Callable[[TrainingProgress], None],
checkpointing_callback: Callable[[TrainingProgress], None],
):
"""
Trains the model.
Expand All @@ -141,8 +150,8 @@ def train(
scheduler (LRScheduler): The learning rate scheduler.
loss_fun (Loss): The loss function used for training.
training_log_interval_in_steps (int): The interval at which training progress is logged.
evaluation_callback (Callable[[int], None]): A callback function for evaluation.
checkpointing_callback (Callable[[int], None]): A callback function for checkpointing.
evaluation_callback (Callable[[TrainingProgress], None]): A callback function for evaluation.
checkpointing_callback (Callable[[TrainingProgress], None]): A callback function for checkpointing.
Returns:
None
Expand All @@ -166,14 +175,20 @@ def train(
gradient_norm_scores = []

# run evaluation callback and checkpointing callback before the first optimizer step
num_train_steps_done = Trainer._get_num_train_steps_done(
micro_batch_id=train_loader.fast_forward_batch_id - 1, gradient_acc_steps=self.gradient_acc_steps
evaluation_callback(num_train_steps_done=self.num_seen_train_steps)
training_progress = TrainingProgress(
num_seen_steps_previous_run=self.num_seen_train_steps,
num_seen_tokens_previous_run=self.global_num_seen_tokens,
num_seen_steps_current_run=0,
num_seen_tokens_current_run=0,
num_target_steps=self.num_target_steps,
num_target_tokens=self.num_target_tokens,
)
evaluation_callback(num_train_steps_done=num_train_steps_done)
checkpointing_callback(num_train_steps_done=num_train_steps_done)
checkpointing_callback(training_progress=training_progress)

num_steps_todo = self.num_target_steps - self.num_seen_train_steps
# Because we might resume training, we add the starting batch id of the data loader
for micro_batch_id, batch in enumerate(train_loader, start=train_loader.fast_forward_batch_id):
for _, (micro_batch_id, batch) in zip(range(num_steps_todo), enumerate(train_loader)):
# Train single batch
(
step_performed,
Expand All @@ -189,6 +204,8 @@ def train(
micro_batch_id=micro_batch_id,
)
forward_backward_time_recorder.stop()
training_progress.num_seen_steps_current_run = num_train_steps_done
training_progress.num_seen_tokens_current_run = self.global_num_tokens_per_train_step * num_train_steps_done

# Save the batch loss
cumulated_losses[0] += batch_loss.item()
Expand All @@ -203,12 +220,12 @@ def train(
thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)

self._publish_progress(
batch_progress_publisher=self.batch_progress_publisher,
num_train_steps_done=num_train_steps_done,
progress_publisher=self.progress_publisher,
num_train_steps_done=training_progress.num_seen_steps_total,
dataloader_tag=train_loader.dataloader_tag,
)
# Check if model performance should be logged
if num_train_steps_done % training_log_interval_in_steps == 0 and step_performed:
if training_progress.num_seen_steps_total % training_log_interval_in_steps == 0 and step_performed:
forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device)
forward_backward_time_recorder.reset()

Expand Down Expand Up @@ -241,7 +258,7 @@ def train(
"train loss last": train_loss_last_batch,
}

consumed_tokens = torch.Tensor([num_train_steps_done * self.global_num_tokens_per_train_step])
consumed_tokens = torch.Tensor([training_progress.num_seen_tokens_total])
metrics = {
"consumed tokens": consumed_tokens,
"grad norm avg": torch.mean(torch.Tensor(gradient_norm_scores)),
Expand All @@ -265,7 +282,7 @@ def train(
"lr mean": torch.tensor(scheduler.get_last_lr()).mean(),
},
dataloader_tag=train_loader.dataloader_tag,
num_train_steps_done=num_train_steps_done,
num_train_steps_done=training_progress.num_seen_steps_total,
)
print_rank_0(training_metrics)
self._publish_evaluation_result(
Expand All @@ -276,8 +293,8 @@ def train(

cumulated_losses = self._reset_tracked_losses()
if step_performed:
evaluation_callback(num_train_steps_done=num_train_steps_done)
checkpointing_callback(num_train_steps_done=num_train_steps_done)
evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total)
checkpointing_callback(training_progress=training_progress)
# we start the time recorder here again to also capture the time spent loading
# via the dataloader.
forward_backward_time_recorder.start()
Expand All @@ -295,18 +312,18 @@ def _reset_tracked_losses(self):

@staticmethod
def _publish_progress(
batch_progress_publisher: MessagePublisher[BatchProgressUpdate],
progress_publisher: MessagePublisher[ProgressUpdate],
num_train_steps_done: int,
dataloader_tag: str,
):
# Publishes the progress of the training, i.e., number of training steps done.

payload = BatchProgressUpdate(
payload = ProgressUpdate(
num_steps_done=num_train_steps_done,
experiment_status=ExperimentStatus.TRAIN,
dataloader_tag=dataloader_tag,
)
batch_progress_publisher.publish_message(payload=payload, message_type=MessageTypes.BATCH_PROGRESS_UPDATE)
progress_publisher.publish_message(payload=payload, message_type=MessageTypes.BATCH_PROGRESS_UPDATE)

@staticmethod
def _publish_evaluation_result(
Expand Down

0 comments on commit bb91e9e

Please sign in to comment.