Commit
[SD] log number of samples instead of number of iterations
ahmadki committed Feb 6, 2024
1 parent 00f04c5 commit 90fc1ed
Showing 1 changed file with 12 additions and 6 deletions.
stable_diffusion/mlperf_logging_utils.py (18 changes: 12 additions & 6 deletions)
@@ -105,15 +105,15 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
     def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
                              batch: Any, batch_idx: int) -> None:
         if trainer.global_step % self.train_log_interval == 0:
-            self.logger.start(key=mllog_constants.BLOCK_START, value="training_step", metadata={mllog_constants.STEP_NUM: trainer.global_step})
+            self.logger.start(key=mllog_constants.BLOCK_START, value="training_step", metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer)})
 
     def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
                            outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
         if trainer.global_step % self.train_log_interval == 0:
             logs = trainer.callback_metrics
-            self.logger.event(key="loss", value=logs["train/loss"].item(), metadata={mllog_constants.STEP_NUM: trainer.global_step})
-            self.logger.event(key="lr_abs", value=logs["lr_abs"].item(), metadata={mllog_constants.STEP_NUM: trainer.global_step})
-            self.logger.end(key=mllog_constants.BLOCK_STOP, value="training_step", metadata={mllog_constants.STEP_NUM: trainer.global_step})
+            self.logger.event(key="loss", value=logs["train/loss"].item(), metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer)})
+            self.logger.event(key="lr_abs", value=logs["lr_abs"].item(), metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer)})
+            self.logger.end(key=mllog_constants.BLOCK_STOP, value="training_step", metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer)})
 
     def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.logger.start(key=mllog_constants.EVAL_START, value=trainer.global_step)
@@ -123,11 +123,11 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
         if "validation/fid" in logs:
             self.logger.event(key=mllog_constants.EVAL_ACCURACY,
                               value=logs["validation/fid"].item(),
-                              metadata={mllog_constants.STEP_NUM: trainer.global_step, "metric": "FID"})
+                              metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer), "metric": "FID"})
         if "validation/clip" in logs:
             self.logger.event(key=mllog_constants.EVAL_ACCURACY,
                               value=logs["validation/clip"].item(),
-                              metadata={mllog_constants.STEP_NUM: trainer.global_step, "metric": "CLIP"})
+                              metadata={mllog_constants.EPOCH_NUM: self._samples_count(trainer), "metric": "CLIP"})
         self.logger.end(key=mllog_constants.EVAL_STOP, value=trainer.global_step)
 
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -146,5 +146,11 @@ def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin
         if batch_idx % self.validation_log_interval == 0:
             self.logger.end(key=mllog_constants.BLOCK_STOP, value="validation_step", metadata={mllog_constants.STEP_NUM: batch_idx})
 
+    def _samples_count(self, trainer: "pl.Trainer") -> int:
+        batch_size_per_gpu = trainer.train_dataloader.batch_size
+        num_gpus = trainer.num_gpus if trainer.num_gpus else 1
+        global_batch_size = batch_size_per_gpu * num_gpus
+
+        return global_batch_size * trainer.global_step
 
 mllogger = SDLogger()
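The net effect of the diff: every EPOCH_NUM metadata field now carries the cumulative number of training samples consumed (global batch size multiplied by optimizer steps) instead of the raw iteration count. Below is a minimal sketch, not part of the commit, of the arithmetic _samples_count performs; the batch size, GPU count, and step count are hypothetical stand-ins for trainer.train_dataloader.batch_size, trainer.num_gpus, and trainer.global_step.

# Hypothetical values illustrating the arithmetic in _samples_count.
batch_size_per_gpu = 32   # stand-in for trainer.train_dataloader.batch_size
num_gpus = 8              # stand-in for trainer.num_gpus (the method falls back to 1 if unset)
global_step = 1000        # stand-in for trainer.global_step

global_batch_size = batch_size_per_gpu * num_gpus  # 32 * 8 = 256
samples_count = global_batch_size * global_step    # 256 * 1000 = 256000

print(samples_count)  # 256000 samples seen after 1000 optimizer steps

Logging samples rather than steps presumably keeps the EPOCH_NUM axis comparable across runs with different global batch sizes, since the same step count can correspond to very different amounts of data.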
