From 61b4a5e7ecf5ba0cde443ded609fc4eb9fe95462 Mon Sep 17 00:00:00 2001
From: Sergey Kolesnikov
Date: Sat, 1 Jun 2019 08:40:53 +0300
Subject: [PATCH] distributed sampler epoch set (#207)

* distributed sampler improvement
---
 catalyst/dl/callbacks/base.py   | 24 ++++++++++++------------
 catalyst/dl/experiments/core.py |  8 ++++++--
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/catalyst/dl/callbacks/base.py b/catalyst/dl/callbacks/base.py
index e401a8366d..f9a5aa6972 100644
--- a/catalyst/dl/callbacks/base.py
+++ b/catalyst/dl/callbacks/base.py
@@ -162,6 +162,18 @@ def __init__(
         self._optimizer_wd = 0
         self._accumulation_counter = 0

+    @staticmethod
+    def grad_step(*, optimizer, optimizer_wd=0, grad_clip_fn=None):
+        for group in optimizer.param_groups:
+            if optimizer_wd > 0:
+                for param in group["params"]:
+                    param.data = param.data.add(
+                        -optimizer_wd * group["lr"], param.data
+                    )
+            if grad_clip_fn is not None:
+                grad_clip_fn(group["params"])
+        optimizer.step()
+
     def on_stage_start(self, state: RunnerState):
         optimizer = state.get_key(
             key="optimizer", inner_key=self.optimizer_key
@@ -179,18 +191,6 @@ def on_epoch_start(self, state):
         self._optimizer_wd = optimizer.param_groups[0].get("weight_decay", 0.0)
         optimizer.param_groups[0]["weight_decay"] = 0.0

-    @staticmethod
-    def grad_step(*, optimizer, optimizer_wd=0, grad_clip_fn=None):
-        for group in optimizer.param_groups:
-            if optimizer_wd > 0:
-                for param in group["params"]:
-                    param.data = param.data.add(
-                        -optimizer_wd * group["lr"], param.data
-                    )
-            if grad_clip_fn is not None:
-                grad_clip_fn(group["params"])
-        optimizer.step()
-
     def on_batch_start(self, state):
         state.loss = None

diff --git a/catalyst/dl/experiments/core.py b/catalyst/dl/experiments/core.py
index 8cbd656cc6..b7e59f2707 100644
--- a/catalyst/dl/experiments/core.py
+++ b/catalyst/dl/experiments/core.py
@@ -5,6 +5,7 @@
 import torch
 from torch import nn, optim
 from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import DistributedSampler

 from catalyst.dl.state import RunnerState
 from catalyst.dl.utils import UtilsFactory
@@ -232,12 +233,15 @@ def _run_epoch(self, loaders):
             assert not any(x.startswith("train") for x in loaders.keys()), \
                 "for inference no train loader should be passed"

-        for loader_name in loaders:
+        for loader_name, loader in loaders.items():
             self.state.loader_name = loader_name
-            self.state.loader_len = len(loaders[loader_name])
+            self.state.loader_len = len(loader)
             self.state.need_backward = loader_name.startswith("train")
             self.model.train(self.state.need_backward)

+            if isinstance(loader.sampler, DistributedSampler):
+                loader.sampler.set_epoch(self.state.stage_epoch)
+
             self._run_event("loader_start")
             with torch.set_grad_enabled(self.state.need_backward):
                 self._run_loader(loaders[loader_name])
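
Why the set_epoch call matters: DistributedSampler derives its shuffle seed from
its internal epoch counter, so unless set_epoch is called before each epoch every
replica replays the same sample order epoch after epoch. Below is a minimal
sketch of that behavior, not part of the patch and independent of Catalyst;
the dataset and variable names are illustrative, and num_replicas/rank are
passed explicitly only so the snippet runs without an initialized process group.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(16).float())
    # rank/num_replicas given explicitly so this runs without torch.distributed init
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

    for epoch in range(3):
        # without this call the sampler's shuffle seed never changes,
        # so each epoch (and each replica) would see an identical order
        sampler.set_epoch(epoch)
        for (batch,) in loader:
            print(epoch, batch.tolist())

In the patched _run_epoch the same call is made with state.stage_epoch, once per
loader, before the loader is iterated.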