Review notes (free text ahead of the first "diff --git" header; patch tools skip it):
  * The patch is restored to one diff line per text line; Python indentation is reconstructed.
  * util.py: Aggregator.get_all_reduced_values returns a dict, so it is now annotated as
    Dict[T, torch.Tensor]; the tensor is placed on CUDA only when CUDA is available instead of
    calling .cuda() unconditionally; a docstring is added.
  * trainer.py: the leftover no-op `torch.device(...)` expression and the commented-out
    `_reset_tracked_losses()` call are dropped; the new-side line numbers of the following
    hunks are shifted accordingly.
  * The post-image blob ids on the "index" lines of trainer.py and util.py predate these
    amendments and are stale.
  * TODO(review): check whether `torch` is still referenced in trainer.py after this change.

diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py
index 0c6b31ed..a75226fc 100644
--- a/src/modalities/evaluator.py
+++ b/src/modalities/evaluator.py
@@ -9,8 +9,7 @@
 from modalities.logging_broker.messages import BatchProgressUpdate, ExperimentStatus, MessageTypes
 from modalities.logging_broker.publisher import MessagePublisher
 from modalities.models.model import model_predict_batch
-from modalities.running_env.fsdp.reducer import Reducer
-from modalities.trainer import ThroughputAggregationKeys
+from modalities.trainer import AggregationKeys
 from modalities.util import Aggregator, TimeRecorder
 
 
@@ -46,17 +45,13 @@ def evaluate(
         result_dict: Dict[str, EvaluationResultBatch] = {}
         model.eval()
 
-        device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu")
-
         for data_loader in data_loaders:
-            cumulated_loss = torch.zeros(3).to(device)
-
             Evaluator._publish_progress(
                 batch_progress_publisher=self.batch_progress_publisher,
                 eval_step_id=0,  # Reset progress bar
                 dataloader_tag=data_loader.dataloader_tag,
             )
-            thoughput_aggregator = Aggregator[ThroughputAggregationKeys]()
+            score_aggregator = Aggregator[AggregationKeys]()
             with TimeRecorder() as forward_backward_timer_recorder:
                 for batch_id, batch in enumerate(data_loader):
                     batch_loss = self.evaluate_batch(
@@ -65,35 +60,43 @@ def evaluate(
                         loss_fun=loss_fun,
                     )
 
-                    cumulated_loss[0] += batch_loss.item()  # sum up batch loss
-                    cumulated_loss[1] += 1
-                    batch_length_tensor = torch.tensor(len(batch)).to(device)
-                    thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)
+                    score_aggregator.add_value(key=AggregationKeys.CUMM_LOSS, value=batch_loss.item())
+                    # This works, because we always drop the last batch in case it has less samples than the batch size
+                    score_aggregator.add_value(key=AggregationKeys.NUM_STEPS, value=1)
+
+                    score_aggregator.add_value(key=AggregationKeys.NUM_SAMPLES, value=len(batch))
 
                     Evaluator._publish_progress(
                         batch_progress_publisher=self.batch_progress_publisher,
                         eval_step_id=batch_id,
                         dataloader_tag=data_loader.dataloader_tag,
                     )
-            # TODO: insert reducer from outside so Evaluator is independent of FSDP
-            total_loss = Reducer.reduce(
-                tensor=cumulated_loss,
-                operation=dist.ReduceOp.SUM,
-                post_processing_fun=lambda t: t[0] / t[1],
+
+            score_aggregator.add_value(
+                key=AggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_timer_recorder.delta_t
+            )
+
+            # reduce the scores with the respective reduction operation
+            sum_reduced_scores = score_aggregator.get_all_reduced_values(
+                keys=[AggregationKeys.NUM_SAMPLES, AggregationKeys.NUM_STEPS, AggregationKeys.CUMM_LOSS],
+                reduce_operation=dist.ReduceOp.SUM,
+            )
+
+            max_reduced_scores = score_aggregator.get_all_reduced_values(
+                keys=[AggregationKeys.FORWARD_BACKWARD_TIME], reduce_operation=dist.ReduceOp.MAX
             )
 
-            forward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to(device)
-            thoughput_aggregator.add_value(
-                key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time
+            # calculate the metric scores for logging
+            num_samples_per_second = (
+                sum_reduced_scores[AggregationKeys.NUM_SAMPLES]
+                / max_reduced_scores[AggregationKeys.FORWARD_BACKWARD_TIME]
             )
-            synced_num_samples = thoughput_aggregator.get_all_reduced_value(ThroughputAggregationKeys.NUM_SAMPLES)
-            synced_forward_backward_time = thoughput_aggregator.get_all_reduced_value(
-                ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX
+            eval_loss_avg = (
+                sum_reduced_scores[AggregationKeys.CUMM_LOSS] / sum_reduced_scores[AggregationKeys.NUM_STEPS]
             )
-            num_samples_per_second = synced_num_samples / synced_forward_backward_time
 
             evaluation_result = EvaluationResultBatch(
-                losses={loss_fun.tag: total_loss},
+                losses={loss_fun.tag: eval_loss_avg},
                 # TODO: hardcoded metric key
                 throughput_metrics={"evaluation_num_samples_per_second": num_samples_per_second},
                 dataloader_tag=data_loader.dataloader_tag,
diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py
index c77f45ba..19250daf 100644
--- a/src/modalities/trainer.py
+++ b/src/modalities/trainer.py
@@ -14,15 +14,18 @@
 from modalities.logging_broker.publisher import MessagePublisher
 from modalities.loss_functions import Loss
 from modalities.models.model import model_predict_batch
-from modalities.running_env.fsdp.reducer import Reducer
 from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
 from modalities.util import Aggregator, TimeRecorder
 
 
-class ThroughputAggregationKeys(Enum):
+class AggregationKeys(Enum):
     NUM_SAMPLES = "NUM_SAMPLES"
     FORWARD_BACKWARD_TIME = "FORWARD_BACKWARD_TIME"
 
+    NUM_STEPS = "NUM_STEPS"
+    CUMM_LOSS = "CUMM_LOSS"
+    LAST_BATCH_LOSS = "LAST_BATCH_LOSS"
+
 
 class Trainer:
     def __init__(
@@ -74,11 +77,8 @@ def train(
         checkpointing_callback: Callable[[int], None],
     ):
         model.train()
-        cumulated_losses = self._reset_tracked_losses()
 
-        thoughput_aggregator = Aggregator[ThroughputAggregationKeys]()
+        score_aggregator = Aggregator[AggregationKeys]()
 
-        device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu")
-
         # batch loop
         batch: DatasetBatch
@@ -102,16 +102,15 @@ def train(
             )
             forward_backward_time_recorder.stop()
             # Save the batch loss
-            cumulated_losses[0] += batch_loss.item()
+            score_aggregator.add_value(key=AggregationKeys.CUMM_LOSS, value=batch_loss.item())
             # This works, because we always drop the last batch in case it has less samples than the batch size
-            cumulated_losses[-1] += 1  # number of local batches
+            score_aggregator.add_value(key=AggregationKeys.NUM_STEPS, value=1)
 
             # gradient norm is already synced across all ranks
             if gradient_norm_score is not None:
                 gradient_norm_scores.append(gradient_norm_score.item())
 
-            batch_length_tensor = torch.tensor(len(batch)).to(device)
-            thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)
+            score_aggregator.add_value(key=AggregationKeys.NUM_SAMPLES, value=len(batch))
 
             self._publish_progress(
                 batch_progress_publisher=self.batch_progress_publisher,
@@ -121,33 +120,39 @@ def train(
 
             # Check, if model should be evaluated
             if (train_step_id + 1) % global_training_log_interval_in_steps == 0:
-                forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device)
+                # add the loss for the LAST batch
+                score_aggregator.add_value(key=AggregationKeys.LAST_BATCH_LOSS, value=batch_loss.item())
+                score_aggregator.add_value(
+                    key=AggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time_recorder.delta_t
+                )
+
                 forward_backward_time_recorder.reset()
 
-                thoughput_aggregator.add_value(
-                    key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time
+                # reduce the scores with the respective reduction operation
+                sum_reduced_scores = score_aggregator.get_all_reduced_values(
+                    keys=[
+                        AggregationKeys.NUM_SAMPLES,
+                        AggregationKeys.NUM_STEPS,
+                        AggregationKeys.CUMM_LOSS,
+                        AggregationKeys.LAST_BATCH_LOSS,
+                    ],
+                    reduce_operation=dist.ReduceOp.SUM,
                 )
-                synced_num_samples = thoughput_aggregator.get_all_reduced_value(ThroughputAggregationKeys.NUM_SAMPLES)
-                synced_forward_backward_time = thoughput_aggregator.get_all_reduced_value(
-                    ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX
-                )
-                synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time
-                # TODO: insert reducer from outside so Trainer is independent of FSDP
-                # add the loss and gradient norm for the LAST batch
-                cumulated_losses[1] = batch_loss.item()
-
-                reduced_losses = Reducer.reduce(
-                    tensor=cumulated_losses,
-                    operation=dist.ReduceOp.SUM,
-                    # 1.) summed batch loss / (num batches * world size)
-                    # 2.) last batch loss / world size
-                    post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]),
+
+                max_reduced_scores = score_aggregator.get_all_reduced_values(
+                    keys=[AggregationKeys.FORWARD_BACKWARD_TIME], reduce_operation=dist.ReduceOp.MAX
                 )
 
-                train_loss_avg, train_loss_last_batch = (
-                    reduced_losses[0],
-                    reduced_losses[1],
+                # calculate the metric scores for logging
+                synced_num_samples_per_second = (
+                    sum_reduced_scores[AggregationKeys.NUM_SAMPLES]
+                    / max_reduced_scores[AggregationKeys.FORWARD_BACKWARD_TIME]
                 )
+                train_loss_avg = (
+                    sum_reduced_scores[AggregationKeys.CUMM_LOSS] / sum_reduced_scores[AggregationKeys.NUM_STEPS]
+                )
+                train_loss_last_batch = sum_reduced_scores[AggregationKeys.LAST_BATCH_LOSS] / dist.get_world_size()
+
                 losses = {
                     f"{loss_fun.tag} average": train_loss_avg,
                     f"{loss_fun.tag} last step": train_loss_last_batch,
@@ -179,10 +184,9 @@ def train(
                     evaluation_result_publisher=self.evaluation_result_publisher,
                     evaluation_result=training_metrics,
                 )
-                thoughput_aggregator.remove_keys()
+                score_aggregator.remove_keys()
 
                 model.train()
-                cumulated_losses = self._reset_tracked_losses()
 
             evaluation_callback(train_step_id=train_step_id)
             checkpointing_callback(train_step_id=train_step_id)
@@ -190,16 +194,6 @@ def train(
             # via the dataloader.
             forward_backward_time_recorder.start()
 
-    def _reset_tracked_losses(self):
-        # TODO: we should handle the device assignment more centrally.
-        # summed lcoal losses, loss of last local batch, number of local batches (i.e., number of steps)
-        cumulated_loss_and_gradient_norm = torch.zeros(3)
-        if torch.cuda.is_available():
-            cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to(torch.device(self.local_rank))
-        else:
-            cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to("cpu")
-        return cumulated_loss_and_gradient_norm
-
     @staticmethod
     def _publish_progress(
         batch_progress_publisher: MessagePublisher[BatchProgressUpdate],
diff --git a/src/modalities/util.py b/src/modalities/util.py
index e769262d..d0299be1 100644
--- a/src/modalities/util.py
+++ b/src/modalities/util.py
@@ -3,7 +3,7 @@
 from datetime import datetime
 from enum import Enum
 from types import TracebackType
-from typing import Callable, Dict, Generic, Type, TypeVar
+from typing import Dict, Generic, List, Type, TypeVar
 
 import torch
 import torch.distributed as dist
@@ -108,31 +108,39 @@ def __repr__(self) -> str:
 
 class Aggregator(Generic[T]):
     def __init__(self):
-        self.key_to_value: Dict[T, torch.Tensor] = {}
+        self.key_to_value: Dict[T, float] = {}
 
-    def add_value(self, key: T, value: torch.Tensor):
+    def add_value(self, key: T, value: float | int):
         if key not in self.key_to_value:
-            self.key_to_value[key] = value
-        else:
-            self.key_to_value[key] += value
+            self.key_to_value[key] = 0
 
-    def remove_key(self, key: T):
-        self.key_to_value.pop(key)
+        self.key_to_value[key] += value
 
     def remove_keys(self):
         self.key_to_value = {}
 
-    def get_all_reduced_value(
+    def get_all_reduced_values(
         self,
-        key: T,
+        keys: List[T],
         reduce_operation: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM,
-        postprocessing_fun: None | Callable[[torch.Tensor], torch.Tensor] = None,
-    ) -> torch.Tensor:
-        # we clone the value so that we can always resync the value without side-effects
-        cloned_value = self.key_to_value[key].clone()
+    ) -> Dict[T, torch.Tensor]:
+        """Reduces the locally aggregated values of the given keys via Reducer.reduce.
+
+        Args:
+            keys (List[T]): Keys whose aggregated values are packed into one tensor and reduced together.
+            reduce_operation (dist.ReduceOp.RedOpType): Reduction applied to every key. Defaults to SUM.
+
+        Returns:
+            Dict[T, torch.Tensor]: Maps each key to the matching element of the reduced tensor.
+        """
+        # A fresh tensor is built from the tracked python scalars on every call, so the values
+        # can always be resynced without side-effects on the aggregator state.
+        # Do not hardcode CUDA here: fall back to CPU when no GPU is available.
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        cloned_value = torch.tensor([self.key_to_value[key] for key in keys], dtype=torch.float32, device=device)
         value = Reducer.reduce(
             tensor=cloned_value,
             operation=reduce_operation,
-            post_processing_fun=postprocessing_fun,  # lambda t: t[0] / t[1],
         )
-        return value
+        reduced_dict = {key: value[i] for i, key in enumerate(keys)}
+        return reduced_dict