This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

working prototype of wandb #271

Open · wants to merge 4 commits into main
Changes from 1 commit
21 changes: 21 additions & 0 deletions vissl/config/defaults.yaml
@@ -135,6 +135,27 @@ config:
      # if we want to log the model parameters every few iterations, set the iteration
      # frequency. -1 means the params will be logged only at the end of epochs.
      LOG_PARAMS_EVERY_N_ITERS: 310

    # ----------------------------------------------------------------------------------- #
    # Weights and Biases (visualization)
    # ----------------------------------------------------------------------------------- #
    WANDB_SETUP:
      # whether to use wandb for the visualization
      USE_WANDB: False
      # log directory for wandb events
      LOG_DIR: "."
      EXPERIMENT_LOG_DIR: "wandb"
      # name of the wandb project
      PROJECT_NAME: "vissl"
      # name of the specific run
      EXP_NAME: "??"
      # whether to log the model parameters to wandb
      LOG_PARAMS: True
      # whether to log the model parameter gradients to wandb
      LOG_PARAMS_GRADIENTS: True
      # if we want to log the model parameters every few iterations, set the iteration
      # frequency. -1 means the params will be logged only at the end of epochs.
      LOG_PARAMS_EVERY_N_ITERS: 310

  # ----------------------------------------------------------------------------------- #
  # DATA
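For context, a minimal sketch of how the new block is meant to be used once the defaults are loaded into the usual `cfg` AttrDict; the override values below are illustrative and not part of this diff:

```python
# Illustrative only: assumes `cfg` is the AttrDict produced from defaults.yaml
# (plus any experiment config). The keys mirror the WANDB_SETUP block above.
from vissl.hooks import default_hook_generator

cfg.HOOKS.WANDB_SETUP.USE_WANDB = True
cfg.HOOKS.WANDB_SETUP.PROJECT_NAME = "vissl"
cfg.HOOKS.WANDB_SETUP.EXP_NAME = "simclr_debug_run"  # example run name

# default_hook_generator (see vissl/hooks/__init__.py below) then appends the wandb hook
hooks = default_hook_generator(cfg)
```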
10 changes: 10 additions & 0 deletions vissl/hooks/__init__.py
@@ -34,6 +34,9 @@
from vissl.hooks.tensorboard_hook import SSLTensorboardHook # noqa
from vissl.utils.tensorboard import get_tensorboard_hook, is_tensorboard_available

from vissl.hooks.wandb_hook import SSLWandbHook # noqa
from vissl.utils.wandb import get_wandb_hook, is_wandb_available


class SSLClassyHookFunctions(Enum):
"""
@@ -115,6 +118,13 @@ def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
        )
        tb_hook = get_tensorboard_hook(cfg)
        hooks.extend([tb_hook])
    if cfg.HOOKS.WANDB_SETUP.USE_WANDB:
        assert is_wandb_available(), (
            "WandB must be installed to use it. "
            "Please install it with: `pip install wandb`"
        )
        wandb_hook = get_wandb_hook(cfg)
        hooks.extend([wandb_hook])
    if cfg.MODEL.GRAD_CLIP.USE_GRAD_CLIP:
        hooks.extend(
            [
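The helpers imported above, `is_wandb_available` and `get_wandb_hook`, live in `vissl/utils/wandb.py`, which is not shown in this commit's diff. A minimal sketch of what they might look like, modeled on the existing tensorboard utilities; the `wandb.init` arguments and the place where the run is initialized are assumptions:

```python
# vissl/utils/wandb.py -- hypothetical sketch, not the file from this PR
import logging


def is_wandb_available():
    """Check whether the wandb package can be imported."""
    try:
        import wandb  # noqa: F401

        return True
    except ImportError:
        return False


def get_wandb_hook(cfg):
    """Build an SSLWandbHook from cfg.HOOKS.WANDB_SETUP (sketch)."""
    import wandb

    # imported here to avoid a circular import with vissl.hooks
    from vissl.hooks.wandb_hook import SSLWandbHook

    wandb_cfg = cfg.HOOKS.WANDB_SETUP
    logging.info(f"Starting wandb run for project: {wandb_cfg.PROJECT_NAME}")
    # assumption: the wandb run is initialized here, on the primary rank only
    wandb.init(
        project=wandb_cfg.PROJECT_NAME,
        name=wandb_cfg.EXP_NAME,
        dir=wandb_cfg.LOG_DIR,
    )
    return SSLWandbHook(
        log_params=wandb_cfg.LOG_PARAMS,
        log_params_every_n_iterations=wandb_cfg.LOG_PARAMS_EVERY_N_ITERS,
        log_params_gradients=wandb_cfg.LOG_PARAMS_GRADIENTS,
    )
```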
226 changes: 226 additions & 0 deletions vissl/hooks/wandb_hook.py
@@ -0,0 +1,226 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging

import torch
from classy_vision import tasks
from classy_vision.generic.distributed_util import is_primary
from classy_vision.hooks.classy_hook import ClassyHook

if is_primary():
Contributor: why do this check? Importing on all ranks should be fine too?

Author: wandb import has some overhead to it, so I was trying to limit it to the main worker, since it's the only one using it. This can be changed.

Contributor: I think it makes sense to me to do it on primary_rank only. :)
    import wandb

BYTE_TO_MiB = 2 ** 20

class SSLWandbHook(ClassyHook):
Contributor: A question on this: is this hook similar to the Tensorboard hook, with the only difference being that it logs to wandb instead of tensorboard?

Is it possible that we can inherit from the TensorboardHook? Or alternatively, does it make sense to extend the TensorboardHook directly to optionally log to wandb as well if the user is using WandB?
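For illustration, a rough sketch of the reviewer's second alternative (a tensorboard hook that also mirrors values to wandb), not part of this PR; the `use_wandb` flag, the `*args/**kwargs` pass-through, and the choice of mirrored scalar are assumptions:

```python
# Hypothetical sketch of extending the tensorboard hook to also log to wandb.
import wandb
from classy_vision.generic.distributed_util import is_primary

from vissl.hooks.tensorboard_hook import SSLTensorboardHook


class SSLTensorboardWandbHook(SSLTensorboardHook):
    def __init__(self, *args, use_wandb: bool = False, **kwargs) -> None:
        super().__init__(*args, **kwargs)  # signature not shown in this diff
        self.use_wandb = use_wandb

    def on_update(self, task) -> None:
        super().on_update(task)  # keep the existing tensorboard logging
        if self.use_wandb and is_primary():
            # mirror one scalar as an example; task.last_batch.loss is the
            # same attribute the wandb hook below reads
            wandb.log(
                {"Training/Loss": task.last_batch.loss.data.cpu().item()},
                step=task.iteration,
            )
```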


    on_loss_and_meter = ClassyHook._noop
    on_backward = ClassyHook._noop
    on_start = ClassyHook._noop
    on_end = ClassyHook._noop
    on_step = ClassyHook._noop

    def __init__(
        self,
        log_params: bool = False,
        log_params_every_n_iterations: int = -1,
        log_params_gradients: bool = False,
    ) -> None:
        """The constructor method of SSLWandbHook.

        Args:
            log_params (bool): whether to log model params to wandb
            log_params_every_n_iterations (int): frequency at which parameters
                should be logged to wandb
            log_params_gradients (bool): whether to log params gradients as well
                to wandb.
        """
        super().__init__()
        # going to assume WandB install check is already performed (TODO: check this)

        logging.info("Setting up SSL Wandb Hook...")
        self.watched = False
        self.log_params = log_params
        self.log_params_every_n_iterations = log_params_every_n_iterations
        self.log_params_gradients = log_params_gradients
        logging.info(
            f"Wandb config: log_params: {self.log_params}, "
            f"log_params_freq: {self.log_params_every_n_iterations}, "
            f"log_params_gradients: {self.log_params_gradients}"
        )


    def on_forward(self, task: "tasks.ClassyTask") -> None:
        """
        Called after every forward pass if the wandb hook is enabled.
        Logs the model parameters if the training iteration matches the
        logging frequency.
        """
        if not self.log_params:
            return

        if (
            self.log_params_every_n_iterations > 0
            and is_primary()
            and task.train
            and task.iteration % self.log_params_every_n_iterations == 0
        ):
            out_dict = {}
            for name, parameter in task.base_model.named_parameters():
                parameter = parameter.cpu().data.numpy()
                out_dict[f"Parameters/{name}"] = wandb.Histogram(parameter)

            wandb.log(out_dict, step=task.iteration)

    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
        """
        Called at the start of every epoch if the wandb hook is enabled.
        Logs the model parameters once at the beginning of training only.
        """
        if not self.log_params:
            return

        # log the parameters just once, before training starts
        if is_primary() and task.train and task.train_phase_idx == 0:
            out_dict = {}
            for name, parameter in task.base_model.named_parameters():
                parameter = parameter.cpu().data.numpy()
                out_dict[f"Parameters/{name}"] = wandb.Histogram(parameter)

            wandb.log(out_dict, step=task.iteration)

    def on_phase_end(self, task: "tasks.ClassyTask") -> None:
        """
        Called at the end of every epoch if the wandb hook is enabled.
        Logs the model parameters and/or parameter gradients as set by the
        user in the wandb configuration. Also resets the CUDA memory counters.
        """
        out_dict = {}

        # Log train/test accuracy
        if is_primary():
            phase_type = "Training" if task.train else "Testing"
            for meter in task.meters:
                if "accuracy" in meter.name:
                    for top_n, accuracies in meter.value.items():
                        for i, acc in accuracies.items():
                            tag_name = f"{phase_type}/Accuracy_" f" {top_n}_Output_{i}"
                            out_dict[tag_name] = round(acc, 5)

        if not (self.log_params or self.log_params_gradients):
            if len(out_dict) > 0:
                wandb.log(out_dict, step=task.iteration)
            return

        if is_primary() and task.train:
            # Log the weights and biases at the end of the epoch
            if self.log_params:
                for name, parameter in task.base_model.named_parameters():
                    parameter = parameter.cpu().data.numpy()
                    out_dict[f"Parameters/{name}"] = wandb.Histogram(parameter)

            # Log the parameter gradients at the end of the epoch
            if self.log_params_gradients:
                for name, parameter in task.base_model.named_parameters():
                    if parameter.grad is not None:
                        try:
                            parameter = parameter.grad.cpu().data.numpy()
                            out_dict[f"Gradients/{name}"] = wandb.Histogram(parameter)
                        except ValueError:
                            logging.info(
                                f"Gradient histogram empty for {name}, "
                                f"iteration {task.iteration}. Unable to "
                                f"log gradient."
                            )

            # Reset the GPU memory counters
            if torch.cuda.is_available():
                torch.cuda.reset_max_memory_allocated()
                torch.cuda.reset_max_memory_cached()

            wandb.log(out_dict, step=task.iteration)


    def on_update(self, task: "tasks.ClassyTask") -> None:
        """
        Called after every parameter update if the wandb hook is enabled.
        Logs the parameter gradients if they are set to be logged, and logs
        scalars like the training loss, learning rate, average training
        iteration time, batch size per gpu, img/sec/gpu, ETA, GPU memory used
        and peak GPU memory used.
        """

        if not is_primary():
            return

        out_dict = {}
        iteration = task.iteration

        if (
            self.log_params_every_n_iterations > 0
            and self.log_params_gradients
            and task.train
            and iteration % self.log_params_every_n_iterations == 0
        ):
            logging.info(f"Logging Parameter gradients. Iteration {iteration}")
            for name, parameter in task.base_model.named_parameters():
                if parameter.grad is not None:
                    try:
                        parameter = parameter.grad.cpu().data.numpy()
                        out_dict[f"Gradients/{name}"] = wandb.Histogram(parameter)
                    except ValueError:
                        logging.info(
                            f"Gradient histogram empty for {name}, "
                            f"iteration {task.iteration}. Unable to "
                            f"log gradient."
                        )

        if iteration % task.config["LOG_FREQUENCY"] == 0 or (
            iteration <= 100 and iteration % 5 == 0
        ):
            logging.info(f"Logging metrics. Iteration {iteration}")
            out_dict["Training/Loss"] = round(task.last_batch.loss.data.cpu().item(), 5)
            out_dict["Training/Learning_rate"] = round(task.optimizer.options_view.lr, 5)

            # Batch processing time
            if len(task.batch_time) > 0:
                batch_times = task.batch_time
            else:
                batch_times = [0]

            batch_time_avg_s = sum(batch_times) / max(len(batch_times), 1)
            out_dict["Speed/Batch_processing_time_ms"] = int(1000.0 * batch_time_avg_s)

            # Images per second per replica
            pic_per_batch_per_gpu = task.config["DATA"]["TRAIN"]["BATCHSIZE_PER_REPLICA"]
            pic_per_batch_per_gpu_per_sec = (
                int(pic_per_batch_per_gpu / batch_time_avg_s)
                if batch_time_avg_s > 0
                else 0.0
            )
            out_dict["Speed/img_per_sec_per_gpu"] = pic_per_batch_per_gpu_per_sec

            # ETA
            avg_time = sum(batch_times) / len(batch_times)
            eta_secs = avg_time * (task.max_iteration - iteration)
            out_dict["Speed/ETA_hours"] = eta_secs / 3600.0

            # GPU memory
            if torch.cuda.is_available():
                # Memory actually being used
                out_dict["Memory/Peak_GPU_Memory_allocated_MiB"] = (
                    torch.cuda.max_memory_allocated() / BYTE_TO_MiB
                )

                # Memory reserved by PyTorch's memory allocator
                out_dict["Memory/Peak_GPU_Memory_reserved_MiB"] = (
                    torch.cuda.max_memory_reserved() / BYTE_TO_MiB
                )

                out_dict["Memory/Current_GPU_Memory_reserved_MiB"] = (
                    torch.cuda.memory_reserved() / BYTE_TO_MiB
                )

        if len(out_dict) > 0:
            wandb.log(out_dict, step=iteration)
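To make the Speed/* metrics above concrete, a small worked example with made-up numbers (0.25 s average batch time, 32 images per replica, 90k iterations remaining):

```python
# Worked example of the Speed/* values logged above; all numbers are illustrative.
batch_times = [0.26, 0.24, 0.25]                        # seconds per iteration
batch_time_avg_s = sum(batch_times) / len(batch_times)  # 0.25 s

batch_processing_time_ms = int(1000.0 * batch_time_avg_s)  # 250 ms
img_per_sec_per_gpu = int(32 / batch_time_avg_s)           # 128 with BATCHSIZE_PER_REPLICA = 32
eta_hours = batch_time_avg_s * 90_000 / 3600.0             # 6.25 hours for 90k remaining iterations

print(batch_processing_time_ms, img_per_sec_per_gpu, eta_hours)
```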