Parameter norm logging #122

Draft: wants to merge 2 commits into main
51 changes: 27 additions & 24 deletions src/modalities/evaluator.py

@@ -9,8 +9,7 @@
 from modalities.logging_broker.messages import BatchProgressUpdate, ExperimentStatus, MessageTypes
 from modalities.logging_broker.publisher import MessagePublisher
 from modalities.models.model import model_predict_batch
-from modalities.running_env.fsdp.reducer import Reducer
-from modalities.trainer import ThroughputAggregationKeys
+from modalities.trainer import AggregationKeys
 from modalities.util import Aggregator, TimeRecorder

@@ -46,17 +45,13 @@ def evaluate(
         result_dict: Dict[str, EvaluationResultBatch] = {}
         model.eval()

-        device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu")
-
         for data_loader in data_loaders:
-            cumulated_loss = torch.zeros(3).to(device)
-
             Evaluator._publish_progress(
                 batch_progress_publisher=self.batch_progress_publisher,
                 eval_step_id=0,  # Reset progress bar
                 dataloader_tag=data_loader.dataloader_tag,
             )
-            thoughput_aggregator = Aggregator[ThroughputAggregationKeys]()
+            score_aggregator = Aggregator[AggregationKeys]()
             with TimeRecorder() as forward_backward_timer_recorder:
                 for batch_id, batch in enumerate(data_loader):
                     batch_loss = self.evaluate_batch(
@@ -65,35 +60,43 @@ def evaluate(
                         batch=batch,
                         loss_fun=loss_fun,
                     )

-                    cumulated_loss[0] += batch_loss.item()  # sum up batch loss
-                    cumulated_loss[1] += 1
-                    batch_length_tensor = torch.tensor(len(batch)).to(device)
-                    thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)
+                    score_aggregator.add_value(key=AggregationKeys.CUMM_LOSS, value=batch_loss.item())
+                    # This works, because we always drop the last batch in case it has less samples than the batch size
+                    score_aggregator.add_value(key=AggregationKeys.NUM_STEPS, value=1)
+
+                    score_aggregator.add_value(key=AggregationKeys.NUM_SAMPLES, value=len(batch))

                     Evaluator._publish_progress(
                         batch_progress_publisher=self.batch_progress_publisher,
                         eval_step_id=batch_id,
                         dataloader_tag=data_loader.dataloader_tag,
                     )
-            # TODO: insert reducer from outside so Evaluator is independent of FSDP
-            total_loss = Reducer.reduce(
-                tensor=cumulated_loss,
-                operation=dist.ReduceOp.SUM,
-                post_processing_fun=lambda t: t[0] / t[1],
-            )
-
-            forward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to(device)
-            thoughput_aggregator.add_value(
-                key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time
-            )
-            synced_num_samples = thoughput_aggregator.get_all_reduced_value(ThroughputAggregationKeys.NUM_SAMPLES)
-            synced_forward_backward_time = thoughput_aggregator.get_all_reduced_value(
-                ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX
-            )
-            num_samples_per_second = synced_num_samples / synced_forward_backward_time
+
+            score_aggregator.add_value(
+                key=AggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_timer_recorder.delta_t
+            )
+
+            # reduce the scores with the respective reduction operation
+            sum_reduced_scores = score_aggregator.get_all_reduced_values(
+                keys=[AggregationKeys.NUM_SAMPLES, AggregationKeys.NUM_STEPS, AggregationKeys.CUMM_LOSS],
+                reduce_operation=dist.ReduceOp.SUM,
+            )
+
+            max_reduced_scores = score_aggregator.get_all_reduced_values(
+                keys=[AggregationKeys.FORWARD_BACKWARD_TIME], reduce_operation=dist.ReduceOp.MAX
+            )
+
+            # calculate the metric scores for logging
+            num_samples_per_second = (
+                sum_reduced_scores[AggregationKeys.NUM_SAMPLES]
+                / max_reduced_scores[AggregationKeys.FORWARD_BACKWARD_TIME]
+            )
+            eval_loss_avg = (
+                sum_reduced_scores[AggregationKeys.CUMM_LOSS] / sum_reduced_scores[AggregationKeys.NUM_STEPS]
+            )

             evaluation_result = EvaluationResultBatch(
-                losses={loss_fun.tag: total_loss},
+                losses={loss_fun.tag: eval_loss_avg},
                 # TODO: hardcoded metric key
                 throughput_metrics={"evaluation_num_samples_per_second": num_samples_per_second},
                 dataloader_tag=data_loader.dataloader_tag,
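For context, a minimal sketch (not part of the diff; two hypothetical ranks with made-up numbers) of how the SUM- and MAX-reduced scores combine into the logged evaluation metrics. `NUM_STEPS` works as a global denominator because short trailing batches are always dropped:

```python
# Loss counters are SUM-reduced across ranks, the forward/backward wall-clock
# time is MAX-reduced, and the logged metrics are simple ratios of the two.
rank_scores = [  # hypothetical per-rank accumulations
    {"CUMM_LOSS": 12.0, "NUM_STEPS": 4, "NUM_SAMPLES": 64, "FORWARD_BACKWARD_TIME": 2.0},
    {"CUMM_LOSS": 10.0, "NUM_STEPS": 4, "NUM_SAMPLES": 64, "FORWARD_BACKWARD_TIME": 2.5},
]

summed = {k: sum(r[k] for r in rank_scores) for k in ("CUMM_LOSS", "NUM_STEPS", "NUM_SAMPLES")}
max_time = max(r["FORWARD_BACKWARD_TIME"] for r in rank_scores)

eval_loss_avg = summed["CUMM_LOSS"] / summed["NUM_STEPS"]  # 22.0 / 8 = 2.75
num_samples_per_second = summed["NUM_SAMPLES"] / max_time  # 128 / 2.5 = 51.2
```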
81 changes: 39 additions & 42 deletions src/modalities/trainer.py

@@ -14,15 +14,18 @@
 from modalities.logging_broker.publisher import MessagePublisher
 from modalities.loss_functions import Loss
 from modalities.models.model import model_predict_batch
-from modalities.running_env.fsdp.reducer import Reducer
 from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
 from modalities.util import Aggregator, TimeRecorder


-class ThroughputAggregationKeys(Enum):
+class AggregationKeys(Enum):
     NUM_SAMPLES = "NUM_SAMPLES"
     FORWARD_BACKWARD_TIME = "FORWARD_BACKWARD_TIME"

+    NUM_STEPS = "NUM_STEPS"
+    CUMM_LOSS = "CUMM_LOSS"
+    LAST_BATCH_LOSS = "LAST_BATCH_LOSS"
+

 class Trainer:
     def __init__(
@@ -74,11 +77,11 @@ def train(
         checkpointing_callback: Callable[[int], None],
     ):
         model.train()
-        cumulated_losses = self._reset_tracked_losses()
+        # cumulated_losses = self._reset_tracked_losses()

-        thoughput_aggregator = Aggregator[ThroughputAggregationKeys]()
+        score_aggregator = Aggregator[AggregationKeys]()

-        device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu")
+        torch.device(self.local_rank if torch.cuda.is_available() else "cpu")

         # batch loop
         batch: DatasetBatch
@@ -102,16 +105,15 @@
             )
             forward_backward_time_recorder.stop()
             # Save the batch loss
-            cumulated_losses[0] += batch_loss.item()
+            score_aggregator.add_value(key=AggregationKeys.CUMM_LOSS, value=batch_loss.item())
             # This works, because we always drop the last batch in case it has less samples than the batch size
-            cumulated_losses[-1] += 1  # number of local batches
+            score_aggregator.add_value(key=AggregationKeys.NUM_STEPS, value=1)

             # gradient norm is already synced across all ranks
             if gradient_norm_score is not None:
                 gradient_norm_scores.append(gradient_norm_score.item())

-            batch_length_tensor = torch.tensor(len(batch)).to(device)
-            thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)
+            score_aggregator.add_value(key=AggregationKeys.NUM_SAMPLES, value=len(batch))

             self._publish_progress(
                 batch_progress_publisher=self.batch_progress_publisher,
@@ -121,33 +123,39 @@

             # Check, if model should be evaluated
             if (train_step_id + 1) % global_training_log_interval_in_steps == 0:
-                forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device)
+                # add the loss for the LAST batch
+                score_aggregator.add_value(key=AggregationKeys.LAST_BATCH_LOSS, value=batch_loss.item())
+                score_aggregator.add_value(
+                    key=AggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time_recorder.delta_t
+                )
+
                 forward_backward_time_recorder.reset()

-                thoughput_aggregator.add_value(
-                    key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time
-                )
-                synced_num_samples = thoughput_aggregator.get_all_reduced_value(ThroughputAggregationKeys.NUM_SAMPLES)
-                synced_forward_backward_time = thoughput_aggregator.get_all_reduced_value(
-                    ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX
-                )
-                synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time
-                # TODO: insert reducer from outside so Trainer is independent of FSDP
-                # add the loss and gradient norm for the LAST batch
-                cumulated_losses[1] = batch_loss.item()
-
-                reduced_losses = Reducer.reduce(
-                    tensor=cumulated_losses,
-                    operation=dist.ReduceOp.SUM,
-                    # 1.) summed batch loss / (num batches * world size)
-                    # 2.) last batch loss / world size
-                    post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]),
-                )
-
-                train_loss_avg, train_loss_last_batch = (
-                    reduced_losses[0],
-                    reduced_losses[1],
-                )
+                # reduce the scores with the respective reduction operation
+                sum_reduced_scores = score_aggregator.get_all_reduced_values(
+                    keys=[
+                        AggregationKeys.NUM_SAMPLES,
+                        AggregationKeys.NUM_STEPS,
+                        AggregationKeys.CUMM_LOSS,
+                        AggregationKeys.LAST_BATCH_LOSS,
+                    ],
+                    reduce_operation=dist.ReduceOp.SUM,
+                )
+
+                max_reduced_scores = score_aggregator.get_all_reduced_values(
+                    keys=[AggregationKeys.FORWARD_BACKWARD_TIME], reduce_operation=dist.ReduceOp.MAX
+                )
+
+                # calculate the metric scores for logging
+                synced_num_samples_per_second = (
+                    sum_reduced_scores[AggregationKeys.NUM_SAMPLES]
+                    / max_reduced_scores[AggregationKeys.FORWARD_BACKWARD_TIME]
+                )
+                train_loss_avg = (
+                    sum_reduced_scores[AggregationKeys.CUMM_LOSS] / sum_reduced_scores[AggregationKeys.NUM_STEPS]
+                )
+                train_loss_last_batch = sum_reduced_scores[AggregationKeys.LAST_BATCH_LOSS] / dist.get_world_size()

                 losses = {
                     f"{loss_fun.tag} average": train_loss_avg,
                     f"{loss_fun.tag} last step": train_loss_last_batch,
@@ -179,27 +187,16 @@ def train(
                     evaluation_result_publisher=self.evaluation_result_publisher,
                     evaluation_result=training_metrics,
                 )
-                thoughput_aggregator.remove_keys()
+                score_aggregator.remove_keys()

                 model.train()
-                cumulated_losses = self._reset_tracked_losses()

             evaluation_callback(train_step_id=train_step_id)
             checkpointing_callback(train_step_id=train_step_id)
             # we start the time recoder here again to also capture the time spend loading
             # via the dataloader.
             forward_backward_time_recorder.start()

-    def _reset_tracked_losses(self):
-        # TODO: we should handle the device assignment more centrally.
-        # summed lcoal losses, loss of last local batch, number of local batches (i.e., number of steps)
-        cumulated_loss_and_gradient_norm = torch.zeros(3)
-        if torch.cuda.is_available():
-            cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to(torch.device(self.local_rank))
-        else:
-            cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to("cpu")
-        return cumulated_loss_and_gradient_norm
-
     @staticmethod
     def _publish_progress(
         batch_progress_publisher: MessagePublisher[BatchProgressUpdate],
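As a side note for reviewers: the last-batch loss is SUM-reduced like the other counters and then divided by the world size, which amounts to a plain mean over ranks. A minimal sketch with made-up numbers:

```python
# Hypothetical: each of four ranks contributes the loss of its last local batch.
world_size = 4
last_batch_losses = [2.1, 1.9, 2.3, 2.0]  # one value per rank

sum_reduced = sum(last_batch_losses)              # what dist.ReduceOp.SUM yields
train_loss_last_batch = sum_reduced / world_size  # 8.3 / 4 = 2.075
```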
24 changes: 10 additions & 14 deletions src/modalities/util.py

@@ -3,7 +3,7 @@
 from datetime import datetime
 from enum import Enum
 from types import TracebackType
-from typing import Callable, Dict, Generic, Type, TypeVar
+from typing import Dict, Generic, List, Type, TypeVar

 import torch
 import torch.distributed as dist
@@ -108,31 +108,27 @@ def __repr__(self) -> str:

 class Aggregator(Generic[T]):
     def __init__(self):
-        self.key_to_value: Dict[T, torch.Tensor] = {}
+        self.key_to_value: Dict[T, float] = {}

-    def add_value(self, key: T, value: torch.Tensor):
+    def add_value(self, key: T, value: float | int):
         if key not in self.key_to_value:
-            self.key_to_value[key] = value
-        else:
-            self.key_to_value[key] += value
+            self.key_to_value[key] = 0

-    def remove_key(self, key: T):
-        self.key_to_value.pop(key)
+        self.key_to_value[key] += value

     def remove_keys(self):
         self.key_to_value = {}

-    def get_all_reduced_value(
+    def get_all_reduced_values(
         self,
-        key: T,
+        keys: List[T],
         reduce_operation: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM,
-        postprocessing_fun: None | Callable[[torch.Tensor], torch.Tensor] = None,
     ) -> torch.Tensor:
         # we clone the value so that we can always resync the value without side-effects
-        cloned_value = self.key_to_value[key].clone()
+        cloned_value = torch.FloatTensor([self.key_to_value[key] for key in keys]).cuda()
         value = Reducer.reduce(
             tensor=cloned_value,
             operation=reduce_operation,
-            post_processing_fun=postprocessing_fun,  # lambda t: t[0] / t[1],
         )
-        return value
+        reduced_dict = {key: value[i] for i, key in enumerate(keys)}
+        return reduced_dict