diff --git a/examples/generative/corrdiff/conf/config_train_diffusion.yaml b/examples/generative/corrdiff/conf/config_train_diffusion.yaml
index f03d8c276f..a77d508a65 100644
--- a/examples/generative/corrdiff/conf/config_train_diffusion.yaml
+++ b/examples/generative/corrdiff/conf/config_train_diffusion.yaml
@@ -61,6 +61,8 @@
 workers: 4

 ## I/O-related options
+wandb_mode: offline
+  # Weights & Biases mode [online, offline, disabled]
 desc: ''
   # String to include in result dir name
 tick: 1
diff --git a/examples/generative/corrdiff/conf/config_train_regression.yaml b/examples/generative/corrdiff/conf/config_train_regression.yaml
index 80877a8f3a..45ee7bbebe 100644
--- a/examples/generative/corrdiff/conf/config_train_regression.yaml
+++ b/examples/generative/corrdiff/conf/config_train_regression.yaml
@@ -61,6 +61,8 @@
 workers: 4

 ## I/O-related options
+wandb_mode: offline
+  # Weights & Biases mode [online, offline, disabled]
 desc: ''
   # String to include in result dir name
 tick: 1
diff --git a/examples/generative/corrdiff/train.py b/examples/generative/corrdiff/train.py
index 81ef9f72ad..0f9eb63343 100644
--- a/examples/generative/corrdiff/train.py
+++ b/examples/generative/corrdiff/train.py
@@ -69,6 +69,7 @@ def main(cfg: DictConfig) -> None:
     workers = getattr(cfg, "workers", 4)

     # Parse I/O-related options
+    wandb_mode = getattr(cfg, "wandb_mode", "disabled")
     desc = getattr(cfg, "desc")
     tick = getattr(cfg, "tick", 1)
     snap = getattr(cfg, "snap", 1)
@@ -80,6 +81,7 @@ def main(cfg: DictConfig) -> None:
     # Parse weather data options
     c = EasyDict()
     c.task = task
+    c.wandb_mode = wandb_mode
     c.train_data_path = getattr(cfg, "train_data_path")
     c.crop_size_x = getattr(cfg, "crop_size_x", 448)
     c.crop_size_y = getattr(cfg, "crop_size_y", 448)
diff --git a/examples/generative/corrdiff/training/training_loop.py b/examples/generative/corrdiff/training/training_loop.py
index e58a3fc3ee..8b68076183 100644
--- a/examples/generative/corrdiff/training/training_loop.py
+++ b/examples/generative/corrdiff/training/training_loop.py
@@ -19,6 +19,7 @@
 import os
 import sys
 import time
+import wandb as wb

 import numpy as np
 import psutil
@@ -29,7 +30,11 @@
 sys.path.append("../")
 from module import Module
 from modulus.distributed import DistributedManager
-from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper
+from modulus.launch.logging import (
+    PythonLogger,
+    RankZeroLoggingWrapper,
+    initialize_wandb,
+)
 from modulus.utils.generative import (
     InfiniteSampler,
     construct_class_by_name,
@@ -81,6 +86,7 @@
     gridtype="sinusoidal",
     N_grid_channels=4,
     normalization="v1",
+    wandb_mode="disabled",
 ):
     """CorrDiff training loop"""

@@ -93,6 +99,15 @@ def training_loop(
     logger0 = RankZeroLoggingWrapper(logger, dist)
     logger.file_logging(file_name=f".logs/training_loop_{dist.rank}.log")

+    # wandb logger
+    initialize_wandb(
+        project="Modulus-Generative",
+        entity="Modulus",
+        name="CorrDiff",
+        group="CorrDiff-DDP-Group",
+        mode=wandb_mode,
+    )
+
     # Initialize.
     start_time = time.time()

@@ -241,6 +256,7 @@
     while True:
         # Accumulate gradients.
         optimizer.zero_grad(set_to_none=True)
+        loss_accum = 0
         for round_idx in range(num_accumulation_rounds):
             with ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
                 # Fetch training data: weather
@@ -261,13 +277,17 @@
                     augment_pipe=augment_pipe,
                 )
                 training_stats.report("Loss/loss", loss)
-                loss.sum().mul(loss_scaling / batch_gpu_total).backward()
+                loss = loss.sum().mul(loss_scaling / batch_gpu_total)
+                loss_accum += loss
+                loss.backward()
+        wb.log({"loss": loss_accum}, step=cur_nimg)

         # Update weights.
         for g in optimizer.param_groups:
             g["lr"] = optimizer_kwargs["lr"] * min(
                 cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1
             )  # TODO better handling (potential bug)
+            wb.log({"lr": g["lr"]}, step=cur_nimg)
         for param in net.parameters():
             if param.grad is not None:
                 torch.nan_to_num(
@@ -324,8 +344,6 @@
             torch.cuda.reset_peak_memory_stats()
         logger0.info(" ".join(fields))

-        ckpt_dir = run_dir
-
         # Save full dump of the training state.
         if (
             (state_dump_ticks is not None)
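
Note for reviewers: the behavior of the new `wandb_mode` option can be checked in isolation with plain `wandb`, independent of Modulus. The sketch below is not part of this PR and uses made-up project and run names; it relies only on documented `wandb.init` modes, where "disabled" turns every `wandb.log` call into a no-op and "offline" stores the run locally for a later `wandb sync`. `initialize_wandb` from `modulus.launch.logging` is assumed here to pass `mode` through to `wandb.init` in the same way.

# Minimal sketch (not part of this PR): wandb mode semantics, hypothetical names.
import wandb as wb

run = wb.init(project="corrdiff-sandbox", name="wandb-mode-check", mode="offline")
for step in range(3):
    wb.log({"loss": 1.0 / (step + 1)}, step=step)  # becomes a no-op with mode="disabled"
wb.finish()  # with mode="offline", the run is kept locally for `wandb sync`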
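A second point worth spelling out is where the new `wb.log({"loss": loss_accum}, ...)` call sits relative to the accumulation loop: each round backpropagates a partial loss already scaled by `loss_scaling / batch_gpu_total`, so the sum over rounds is the effective loss for the whole optimizer step and should be logged once per step, not once per round. Below is a stand-alone sketch of that pattern; the toy parameter, loop bounds, and loss are hypothetical, and the W&B call is replaced by a print so it runs without an active run.

# Sketch of the gradient-accumulation logging pattern (toy values, not the PR's code).
import torch

loss_scaling, batch_gpu_total, num_accumulation_rounds = 1.0, 16, 4
param = torch.nn.Parameter(torch.randn(8))
optimizer = torch.optim.SGD([param], lr=1e-3)

optimizer.zero_grad(set_to_none=True)
loss_accum = 0
for round_idx in range(num_accumulation_rounds):
    per_sample = (param - 1.0) ** 2                  # stand-in for loss_fn(...) output
    loss = per_sample.sum().mul(loss_scaling / batch_gpu_total)
    loss_accum += loss.detach()                      # accumulate the value for logging
    loss.backward()                                  # gradients add up across rounds
print({"loss": float(loss_accum)})                   # the PR logs this via wb.log(..., step=cur_nimg)
optimizer.step()

Detaching in the sketch just keeps it from holding a growing autograd graph; the diff accumulates the loss tensors directly before handing the result to `wb.log`.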