CorrDiff integration: Support wandb logging (#316)
* Update blossom-ci.yml (#295)

* Change pip install commands with the correct PyPI package name (#298)

* add wb logging

* formatting

* make mode configurable

---------

Co-authored-by: Kaustubh Tangsali <[email protected]>
Co-authored-by: Abdullah <[email protected]>
3 people authored Jan 25, 2024
1 parent 85c10f9 commit 24bee5c
Showing 4 changed files with 28 additions and 4 deletions.
2 changes: 2 additions & 0 deletions examples/generative/corrdiff/conf/config_train_diffusion.yaml
@@ -61,6 +61,8 @@ workers: 4


## I/O-related options
wandb_mode: offline
# Weights & Biases mode [online, offline, disabled]
desc: ''
# String to include in result dir name
tick: 1
2 changes: 2 additions & 0 deletions
Expand Up @@ -61,6 +61,8 @@ workers: 4


## I/O-related options
wandb_mode: offline
# Weights & Biases mode [online, offline, disabled]
desc: ''
# String to include in result dir name
tick: 1
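For reference, the three accepted values are Weights & Biases' standard run modes: online streams metrics to the service, offline writes the run locally for a later "wandb sync", and disabled turns every wandb call into a no-op. A quick, self-contained way to exercise the two network-free modes (hypothetical throwaway project name; standard wandb API):

import wandb

# "offline" records the run under ./wandb for a later "wandb sync";
# "disabled" makes every wandb call a no-op. Neither needs a W&B account
# or network access, which is why "offline" is a safe cluster default.
for mode in ("offline", "disabled"):
    run = wandb.init(project="scratch-mode-check", name=f"check-{mode}", mode=mode)
    wandb.log({"dummy": 1.0})
    run.finish()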
2 changes: 2 additions & 0 deletions examples/generative/corrdiff/train.py
@@ -69,6 +69,7 @@ def main(cfg: DictConfig) -> None:
workers = getattr(cfg, "workers", 4)

# Parse I/O-related options
wandb_mode = getattr(cfg, "wandb_mode", "disabled")
desc = getattr(cfg, "desc")
tick = getattr(cfg, "tick", 1)
snap = getattr(cfg, "snap", 1)
@@ -80,6 +81,7 @@ def main(cfg: DictConfig) -> None:
# Parse weather data options
c = EasyDict()
c.task = task
c.wandb_mode = wandb_mode
c.train_data_path = getattr(cfg, "train_data_path")
c.crop_size_x = getattr(cfg, "crop_size_x", 448)
c.crop_size_y = getattr(cfg, "crop_size_y", 448)
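Note that the YAML default above is offline while train.py falls back to disabled when the key is absent, so configs written before this change keep their previous behaviour. A minimal illustration of that fallback, with a plain namespace standing in for the Hydra config (hypothetical values, not the real DictConfig):

from types import SimpleNamespace

# Stand-ins for the Hydra config object.
old_cfg = SimpleNamespace(task="diffusion")                        # no wandb_mode key
new_cfg = SimpleNamespace(task="diffusion", wandb_mode="offline")

print(getattr(old_cfg, "wandb_mode", "disabled"))  # -> disabled (logging stays off)
print(getattr(new_cfg, "wandb_mode", "disabled"))  # -> offline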
26 changes: 22 additions & 4 deletions examples/generative/corrdiff/training/training_loop.py
@@ -19,6 +19,7 @@
import os
import sys
import time
import wandb as wb

import numpy as np
import psutil
@@ -29,7 +30,11 @@
sys.path.append("../")
from module import Module
from modulus.distributed import DistributedManager
from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper
from modulus.launch.logging import (
PythonLogger,
RankZeroLoggingWrapper,
initialize_wandb,
)
from modulus.utils.generative import (
InfiniteSampler,
construct_class_by_name,
@@ -81,6 +86,7 @@ def training_loop(
gridtype="sinusoidal",
N_grid_channels=4,
normalization="v1",
wandb_mode="disabled",
):
"""CorrDiff training loop"""

@@ -93,6 +99,15 @@
logger0 = RankZeroLoggingWrapper(logger, dist)
logger.file_logging(file_name=f".logs/training_loop_{dist.rank}.log")

# wandb logger
initialize_wandb(
project="Modulus-Generative",
entity="Modulus",
name="CorrDiff",
group="CorrDiff-DDP-Group",
mode=wandb_mode,
)

# Initialize.
start_time = time.time()

@@ -241,6 +256,7 @@ def training_loop(
while True:
# Accumulate gradients.
optimizer.zero_grad(set_to_none=True)
loss_accum = 0
for round_idx in range(num_accumulation_rounds):
with ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
# Fetch training data: weather
@@ -261,13 +277,17 @@
augment_pipe=augment_pipe,
)
training_stats.report("Loss/loss", loss)
loss.sum().mul(loss_scaling / batch_gpu_total).backward()
loss = loss.sum().mul(loss_scaling / batch_gpu_total)
loss_accum += loss
loss.backward()
wb.log({"loss": loss_accum}, step=cur_nimg)

# Update weights.
for g in optimizer.param_groups:
g["lr"] = optimizer_kwargs["lr"] * min(
cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1
) # TODO better handling (potential bug)
wb.log({"lr": g["lr"]}, step=cur_nimg)
for param in net.parameters():
if param.grad is not None:
torch.nan_to_num(
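Taken together, the additions in this hunk emit one loss value and the ramped learning rate per optimizer step. A condensed, self-contained sketch of that pattern (hypothetical model, data and hyper-parameters; standard PyTorch and wandb APIs, not the actual CorrDiff loop):

import wandb as wb

def train_step(net, optimizer, batches, loss_fn, cur_nimg,
               base_lr=2e-4, lr_rampup_kimg=10):
    optimizer.zero_grad(set_to_none=True)
    loss_accum = 0.0
    for batch in batches:                          # gradient-accumulation rounds
        loss = loss_fn(net, batch) / len(batches)  # scale so rounds sum to one step
        loss_accum += loss.item()                  # keep a plain scalar for logging
        loss.backward()
    wb.log({"loss": loss_accum}, step=cur_nimg)    # one point per optimizer step

    # Linear ramp-up: lr grows from ~0 to base_lr over the first
    # lr_rampup_kimg * 1000 images seen, then stays constant.
    for g in optimizer.param_groups:
        g["lr"] = base_lr * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
        wb.log({"lr": g["lr"]}, step=cur_nimg)
    optimizer.step()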
@@ -324,8 +344,6 @@ def training_loop(
torch.cuda.reset_peak_memory_stats()
logger0.info(" ".join(fields))

ckpt_dir = run_dir

# Save full dump of the training state.
if (
(state_dump_ticks is not None)
