Skip to content

Commit

Permalink
W&B: add artifacts support and fix logging steps (#1309)
Browse files Browse the repository at this point in the history
* improve wandb logger

* changelog

* Update catalyst/loggers/wandb.py

* Update catalyst/loggers/wandb.py

Co-authored-by: Sergey Kolesnikov <[email protected]>
  • Loading branch information
AyushExel and Scitator authored Sep 29, 2021
1 parent 679caf7 commit bd3a0d7
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


### Changed

- Improved `WandbLogger` to support artifacts and fix logging steps ([#1309](https://github.com/catalyst-team/catalyst/pull/1309))
- full `Runner` cleanup, with callbacks and loaders destruction, moved to `PipelineParallelFairScaleEngine` only ([#1295](https://github.com/catalyst-team/catalyst/pull/1295))
- `HuberLoss` renamed to `HuberLossV0` for the PyTorch compatibility ([#1295](https://github.com/catalyst-team/catalyst/pull/1295))
- codestyle update ([#1298](https://github.com/catalyst-team/catalyst/pull/1298))
Expand Down
62 changes: 58 additions & 4 deletions catalyst/loggers/wandb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Dict, Optional
import os
import pickle

import numpy as np

Expand Down Expand Up @@ -112,18 +114,18 @@ def log_metrics(
if scope == "batch":
metrics = {k: float(v) for k, v in metrics.items()}
self._log_metrics(
metrics=metrics, step=global_epoch_step, loader_key=loader_key, prefix="batch"
metrics=metrics, step=global_sample_step, loader_key=loader_key, prefix="batch"
)
elif scope == "loader":
self._log_metrics(
metrics=metrics, step=global_epoch_step, loader_key=loader_key, prefix="epoch"
metrics=metrics, step=global_sample_step, loader_key=loader_key, prefix="epoch"
)
elif scope == "epoch":
loader_key = "_epoch_"
per_loader_metrics = metrics[loader_key]
self._log_metrics(
metrics=per_loader_metrics,
step=global_epoch_step,
step=global_sample_step,
loader_key=loader_key,
prefix="epoch",
)
Expand Down Expand Up @@ -154,7 +156,7 @@ def log_image(
"""Logs image to the logger."""
self.run.log(
{f"{tag}_scope_{scope}_epoch_{global_epoch_step}.png": wandb.Image(image)},
step=global_epoch_step,
step=global_sample_step,
)

def log_hparams(
Expand All @@ -168,12 +170,64 @@ def log_hparams(
"""Logs hyperparameters to the logger."""
self.run.config.update(hparams)

def log_artifact(
self,
tag: str,
artifact: object = None,
path_to_artifact: str = None,
scope: str = None,
# experiment info
run_key: str = None,
global_epoch_step: int = 0,
global_batch_step: int = 0,
global_sample_step: int = 0,
# stage info
stage_key: str = None,
stage_epoch_len: int = 0,
stage_epoch_step: int = 0,
stage_batch_step: int = 0,
stage_sample_step: int = 0,
# loader info
loader_key: str = None,
loader_batch_len: int = 0,
loader_sample_len: int = 0,
loader_batch_step: int = 0,
loader_sample_step: int = 0,
) -> None:
"""Logs artifact (arbitrary file like audio, video, model weights) to the logger."""
if artifact is None and path_to_artifact is None:
ValueError("Both artifact and path_to_artifact cannot be None")

artifact = wandb.Artifact(
name=self.run.id + "_aritfacts",
type="artifact",
metadata={
"stage_key": stage_key,
"loader_key": loader_key,
"scope": scope,
},
)

if artifact:
art_file_dir = os.path.join("wandb", self.run.id, "artifact_dumps")
os.makedirs(art_file_dir, exist_ok=True)

art_file = open(os.path.join(art_file_dir, tag), "wb")
pickle.dump(artifact, art_file)
art_file.close()

artifact.add_file(str(os.path.join(art_file_dir, tag)))
else:
artifact.add_file(path_to_artifact)
self.run.log_artifact(artifact)

def flush_log(self) -> None:
"""Flushes the logger."""
pass

def close_log(self, scope: str = None) -> None:
"""Closes the logger."""
# Artifacts can be logged after call to close_log()
if scope is None or scope == "experiment":
self.run.finish()

Expand Down

0 comments on commit bd3a0d7

Please sign in to comment.