diff --git a/stable_diffusion/mlperf_logging_utils.py b/stable_diffusion/mlperf_logging_utils.py
index b49c8d959..ff5607e52 100644
--- a/stable_diffusion/mlperf_logging_utils.py
+++ b/stable_diffusion/mlperf_logging_utils.py
@@ -83,6 +83,8 @@ def __init__(self, logger, train_log_interval=5, validation_log_interval=1):
         self.logger = mllogger
         self.train_log_interval = train_log_interval
         self.validation_log_interval = validation_log_interval
+        self.train_samples_counter = 0
+        self.val_samples_counter = 0
 
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.logger.start(mllog_constants.RUN_START)
@@ -105,46 +107,53 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
     def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int) -> None:
         if trainer.global_step % self.train_log_interval == 0:
-            self.logger.start(key=mllog_constants.BLOCK_START, value="training_step", metadata={mllog_constants.STEP_NUM: trainer.global_step})
+            self.logger.start(key=mllog_constants.BLOCK_START, value="training_step", metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
 
     def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
+        batch_size = len(batch['txt'])
+        total_batch_sizes = pl_module.all_gather(batch_size).sum().item()  # Gather batch sizes from all devices
+        self.train_samples_counter += total_batch_sizes
+
         if trainer.global_step % self.train_log_interval == 0:
             logs = trainer.callback_metrics
-            self.logger.event(key="loss", value=logs["train/loss"].item(), metadata={mllog_constants.STEP_NUM: trainer.global_step})
-            self.logger.event(key="lr_abs", value=logs["lr_abs"].item(), metadata={mllog_constants.STEP_NUM: trainer.global_step})
-            self.logger.end(key=mllog_constants.BLOCK_STOP, value="training_step", metadata={mllog_constants.STEP_NUM: trainer.global_step})
+            self.logger.event(key="loss", value=logs["train/loss"].item(), metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
+            self.logger.event(key="lr_abs", value=logs["lr_abs"].item(), metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
+            self.logger.end(key=mllog_constants.BLOCK_STOP, value="training_step", metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
 
     def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self.logger.start(key=mllog_constants.EVAL_START, value=trainer.global_step)
+        self.logger.start(key=mllog_constants.EVAL_START, value=self.train_samples_counter, metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
 
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         logs = trainer.callback_metrics
         if "validation/fid" in logs:
             self.logger.event(key=mllog_constants.EVAL_ACCURACY,
                               value=logs["validation/fid"].item(),
-                              metadata={mllog_constants.STEP_NUM: trainer.global_step, "metric": "FID"})
+                              metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter, "metric": "FID"})
         if "validation/clip" in logs:
             self.logger.event(key=mllog_constants.EVAL_ACCURACY,
                               value=logs["validation/clip"].item(),
-                              metadata={mllog_constants.STEP_NUM: trainer.global_step, "metric": "CLIP"})
-        self.logger.end(key=mllog_constants.EVAL_STOP, value=trainer.global_step)
+                              metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter, "metric": "CLIP"})
+        self.logger.end(key=mllog_constants.EVAL_STOP, value=self.train_samples_counter, metadata={mllog_constants.SAMPLES_COUNT: self.train_samples_counter})
 
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        pass
+        self.val_samples_counter = 0
 
     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         pass
 
     def on_validation_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+        batch_size = len(batch['caption'])
+        total_batch_sizes = pl_module.all_gather(batch_size).sum().item()  # Gather batch sizes from all devices
+        self.val_samples_counter += total_batch_sizes
         if batch_idx % self.validation_log_interval == 0:
-            self.logger.start(key=mllog_constants.BLOCK_START, value="validation_step", metadata={mllog_constants.STEP_NUM: batch_idx})
+            self.logger.start(key=mllog_constants.BLOCK_START, value="validation_step", metadata={mllog_constants.SAMPLES_COUNT: self.val_samples_counter})
 
     def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         if batch_idx % self.validation_log_interval == 0:
-            self.logger.end(key=mllog_constants.BLOCK_STOP, value="validation_step", metadata={mllog_constants.STEP_NUM: batch_idx})
+            self.logger.end(key=mllog_constants.BLOCK_STOP, value="validation_step", metadata={mllog_constants.SAMPLES_COUNT: self.val_samples_counter})
 
 mllogger = SDLogger()
diff --git a/stable_diffusion/requirements.txt b/stable_diffusion/requirements.txt
index 26802b5ff..8ed4e9675 100644
--- a/stable_diffusion/requirements.txt
+++ b/stable_diffusion/requirements.txt
@@ -22,4 +22,4 @@ diffusers==0.14.0
 cloudpathlib==0.13.0
 git+https://github.com/facebookresearch/xformers.git@5eb0dbf315d14b5f7b38ac2ff3d8379beca7df9b#egg=xformers
 bitsandbytes==0.37.2
-git+https://github.com/mlcommons/logging.git@8405a08bbfc724f8888c419461c02d55a6ac960c
+git+https://github.com/ahmadki/logging.git@6328dfcd04a1bda8ae5b3e4f1f586e9037d42ddf
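For review context, a minimal sketch of the distributed sample-counting idiom the new hooks rely on: `LightningModule.all_gather` accepts a plain Python int, stacks the per-rank values into one tensor, and summing that tensor yields the global number of samples consumed in the step. The `'txt'` batch key follows the diff; the callback class name and everything else below are illustrative, not part of the patch.

```python
# Hedged sketch (illustrative, not part of the patch) of the cross-device
# sample counting added in on_train_batch_end above.
from typing import Any

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class SampleCountingCallback(Callback):  # hypothetical name
    def __init__(self) -> None:
        # Monotonic count of training samples seen across all ranks.
        self.train_samples_counter = 0

    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
                           outputs: Any, batch: Any, batch_idx: int) -> None:
        batch_size = len(batch["txt"])  # local (per-rank) batch size
        # all_gather stacks the per-rank scalar into a tensor of shape
        # (world_size,), so the sum is the global batch size even when
        # ranks see uneven batches (e.g. a final batch without drop_last).
        self.train_samples_counter += pl_module.all_gather(batch_size).sum().item()
```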
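And a hedged sketch of what the metadata switch amounts to at an mllog call site. `SAMPLES_COUNT` is assumed to be available from the logging dependency pinned in requirements.txt; the values are illustrative.

```python
# Hedged sketch: STEP_NUM -> SAMPLES_COUNT at a single mllog call site.
from mlperf_logging import mllog
from mlperf_logging.mllog import constants as mllog_constants

mllogger = mllog.get_mllogger()

# Previously, events were keyed by optimizer step:
#   metadata={mllog_constants.STEP_NUM: trainer.global_step}
# Now they are keyed by samples consumed, which stays comparable across
# runs that train with different global batch sizes:
mllogger.event(key="loss", value=0.123,  # illustrative value
               metadata={mllog_constants.SAMPLES_COUNT: 51200})
```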