[SD] log number of samples instead of number of iterations
ahmadki committed Feb 25, 2024
1 parent 00f04c5 commit 20391e0
Showing 2 changed files with 21 additions and 12 deletions.
31 changes: 20 additions & 11 deletions stable_diffusion/mlperf_logging_utils.py
@@ -83,6 +83,8 @@ def __init__(self, logger, train_log_interval=5, validation_log_interval=1):
         self.logger = mllogger
         self.train_log_interval = train_log_interval
         self.validation_log_interval = validation_log_interval
+        self.train_samples_counter = 0
+        self.val_samples_counter = 0
 
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.logger.start(mllog_constants.RUN_START)
@@ -105,46 +107,53 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
     def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
                              batch: Any, batch_idx: int) -> None:
         if trainer.global_step % self.train_log_interval == 0:
-            self.logger.start(key=mllog_constants.BLOCK_START, value="training_step", metadata={mllog_constants.STEP_NUM: trainer.global_step})
+            self.logger.start(key=mllog_constants.BLOCK_START, value="training_step", metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
 
     def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
                            outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
+        batch_size = len(batch['txt'])
+        total_batch_sizes = pl_module.all_gather(batch_size).sum().item() # Gather batch sizes from all devices
+        self.train_samples_counter += total_batch_sizes
+
         if trainer.global_step % self.train_log_interval == 0:
             logs = trainer.callback_metrics
-            self.logger.event(key="loss", value=logs["train/loss"].item(), metadata={mllog_constants.STEP_NUM: trainer.global_step})
-            self.logger.event(key="lr_abs", value=logs["lr_abs"].item(), metadata={mllog_constants.STEP_NUM: trainer.global_step})
-            self.logger.end(key=mllog_constants.BLOCK_STOP, value="training_step", metadata={mllog_constants.STEP_NUM: trainer.global_step})
+            self.logger.event(key="loss", value=logs["train/loss"].item(), metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
+            self.logger.event(key="lr_abs", value=logs["lr_abs"].item(), metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
+            self.logger.end(key=mllog_constants.BLOCK_STOP, value="training_step", metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
 
     def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self.logger.start(key=mllog_constants.EVAL_START, value=trainer.global_step)
+        self.logger.start(key=mllog_constants.EVAL_START, value=self.train_samples_counter, metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
 
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         logs = trainer.callback_metrics
         if "validation/fid" in logs:
             self.logger.event(key=mllog_constants.EVAL_ACCURACY,
                               value=logs["validation/fid"].item(),
-                              metadata={mllog_constants.STEP_NUM: trainer.global_step, "metric": "FID"})
+                              metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter, "metric": "FID"})
         if "validation/clip" in logs:
             self.logger.event(key=mllog_constants.EVAL_ACCURACY,
                               value=logs["validation/clip"].item(),
-                              metadata={mllog_constants.STEP_NUM: trainer.global_step, "metric": "CLIP"})
-        self.logger.end(key=mllog_constants.EVAL_STOP, value=trainer.global_step)
+                              metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter, "metric": "CLIP"})
+        self.logger.end(key=mllog_constants.EVAL_STOP, value=self.train_samples_counter, metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
 
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        pass
+        self.val_samples_counter = 0
 
     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         pass
 
     def on_validation_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
                                   batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        batch_size = len(batch['caption'])
+        total_batch_sizes = pl_module.all_gather(batch_size).sum().item() # Gather batch sizes from all devices
+        self.val_samples_counter += total_batch_sizes
         if batch_idx % self.validation_log_interval == 0:
-            self.logger.start(key=mllog_constants.BLOCK_START, value="validation_step", metadata={mllog_constants.STEP_NUM: batch_idx})
+            self.logger.start(key=mllog_constants.BLOCK_START, value="validation_step", metadata={mllog_constants.SAMPLES_COUNT: self.val_samples_counter})
 
     def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: Optional[STEP_OUTPUT],
                                 batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         if batch_idx % self.validation_log_interval == 0:
-            self.logger.end(key=mllog_constants.BLOCK_STOP, value="validation_step", metadata={mllog_constants.STEP_NUM: batch_idx})
+            self.logger.end(key=mllog_constants.BLOCK_STOP, value="validation_step", metadata={mllog_constants.SAMPLES_COUNT: self.val_samples_counter})
 
 
 mllogger = SDLogger()
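For context, here is a minimal, self-contained sketch (not the repository's code) of the sample-counting pattern this diff introduces. In the callback, `pl_module.all_gather(batch_size)` collects one local batch size per rank and `.sum().item()` reduces them to the global number of samples consumed that step. The `all_gather_stub` helper and `SampleCounter` class below are hypothetical stand-ins so the example runs in a single process:

```python
from typing import List


def all_gather_stub(value: int, world_size: int = 4) -> List[int]:
    # Hypothetical stand-in for pl_module.all_gather: every rank
    # contributes its local batch size; in this toy, all ranks report
    # the same value.
    return [value] * world_size


class SampleCounter:
    def __init__(self) -> None:
        self.train_samples_counter = 0

    def on_train_batch_end(self, local_batch_size: int) -> None:
        # Sum the per-rank batch sizes to advance the global counter,
        # mirroring pl_module.all_gather(batch_size).sum().item().
        self.train_samples_counter += sum(all_gather_stub(local_batch_size))


counter = SampleCounter()
for _ in range(3):  # three optimizer steps
    counter.on_train_batch_end(local_batch_size=8)
print(counter.train_samples_counter)  # 96 = 3 steps * 4 ranks * 8 samples/rank
```

A plausible motivation for the change: logging `samples_count` instead of `step_num` makes runs with different global batch sizes directly comparable, since progress is measured in samples seen rather than optimizer steps.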
2 changes: 1 addition & 1 deletion stable_diffusion/requirements.txt
@@ -22,4 +22,4 @@ diffusers==0.14.0
 cloudpathlib==0.13.0
 git+https://github.com/facebookresearch/xformers.git@5eb0dbf315d14b5f7b38ac2ff3d8379beca7df9b#egg=xformers
 bitsandbytes==0.37.2
-git+https://github.com/mlcommons/logging.git@8405a08bbfc724f8888c419461c02d55a6ac960c
+git+https://github.com/ahmadki/logging.git@6328dfcd04a1bda8ae5b3e4f1f586e9037d42ddf
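The dependency pin moves from mlcommons/logging to a fork, presumably so the `SAMPLES_COUNT` constant used above is available before it lands upstream. A quick, hypothetical sanity check, assuming the usual mlperf-logging package layout (`mlperf_logging.mllog.constants`):

```python
# Hypothetical check that the pinned logging package exposes the new key;
# the import path assumes the standard mlperf-logging layout.
from mlperf_logging.mllog import constants as mllog_constants

assert hasattr(mllog_constants, "SAMPLES_COUNT"), "logging pin lacks SAMPLES_COUNT"
print(mllog_constants.SAMPLES_COUNT)
```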
