distributed sampler epoch set (#207)
* distributed sampler improvement
Scitator authored Jun 1, 2019
1 parent 5e3afe5 commit 61b4a5e
Showing 2 changed files with 18 additions and 14 deletions.
24 changes: 12 additions & 12 deletions catalyst/dl/callbacks/base.py
@@ -162,6 +162,18 @@ def __init__(
         self._optimizer_wd = 0
         self._accumulation_counter = 0

+    @staticmethod
+    def grad_step(*, optimizer, optimizer_wd=0, grad_clip_fn=None):
+        for group in optimizer.param_groups:
+            if optimizer_wd > 0:
+                for param in group["params"]:
+                    param.data = param.data.add(
+                        -optimizer_wd * group["lr"], param.data
+                    )
+            if grad_clip_fn is not None:
+                grad_clip_fn(group["params"])
+        optimizer.step()
+
     def on_stage_start(self, state: RunnerState):
         optimizer = state.get_key(
             key="optimizer", inner_key=self.optimizer_key
@@ -179,18 +191,6 @@ def on_epoch_start(self, state):
         self._optimizer_wd = optimizer.param_groups[0].get("weight_decay", 0.0)
         optimizer.param_groups[0]["weight_decay"] = 0.0

-    @staticmethod
-    def grad_step(*, optimizer, optimizer_wd=0, grad_clip_fn=None):
-        for group in optimizer.param_groups:
-            if optimizer_wd > 0:
-                for param in group["params"]:
-                    param.data = param.data.add(
-                        -optimizer_wd * group["lr"], param.data
-                    )
-            if grad_clip_fn is not None:
-                grad_clip_fn(group["params"])
-        optimizer.step()
-
     def on_batch_start(self, state):
         state.loss = None

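The hunks above only relocate grad_step so that it sits before on_stage_start; the method body itself is unchanged. For reference, here is a minimal standalone sketch (not Catalyst code) of the same per-parameter-group decoupled weight decay plus gradient clipping that grad_step performs before optimizer.step(); the model, learning rate, clip threshold, and the choice of clip_grad_norm_ are illustrative assumptions.

# Illustrative sketch, not part of the commit.
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
weight_decay = 1e-4  # plays the role of optimizer_wd

loss = model(torch.randn(4, 10)).sum()
loss.backward()

for group in optimizer.param_groups:
    if weight_decay > 0:
        for param in group["params"]:
            # decoupled weight decay: shrink the weights directly,
            # without routing the penalty through the gradients
            param.data.add_(param.data, alpha=-weight_decay * group["lr"])
    # stands in for grad_clip_fn; any callable over the group's params works
    clip_grad_norm_(group["params"], max_norm=1.0)
optimizer.step()
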
8 changes: 6 additions & 2 deletions catalyst/dl/experiments/core.py
@@ -5,6 +5,7 @@
 import torch
 from torch import nn, optim
 from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import DistributedSampler

 from catalyst.dl.state import RunnerState
 from catalyst.dl.utils import UtilsFactory
@@ -232,12 +233,15 @@ def _run_epoch(self, loaders):
             assert not any(x.startswith("train") for x in loaders.keys()), \
                 "for inference no train loader should be passed"

-        for loader_name in loaders:
+        for loader_name, loader in loaders.items():
             self.state.loader_name = loader_name
-            self.state.loader_len = len(loaders[loader_name])
+            self.state.loader_len = len(loader)
             self.state.need_backward = loader_name.startswith("train")
             self.model.train(self.state.need_backward)

+            if isinstance(loader.sampler, DistributedSampler):
+                loader.sampler.set_epoch(self.state.stage_epoch)
+
             self._run_event("loader_start")
             with torch.set_grad_enabled(self.state.need_backward):
                 self._run_loader(loaders[loader_name])
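This is the substantive change: when a loader's sampler is a DistributedSampler, its epoch is now set from state.stage_epoch before the loader is run. DistributedSampler derives its shuffle order from the epoch number, so without set_epoch every epoch on a given process iterates the data in the same order. Below is a minimal standalone sketch of the pattern; num_replicas and rank are hard-coded only so the example runs without init_process_group(), and the dataset is a stand-in.

# Illustrative sketch, not part of the commit.
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100).float())
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for epoch in range(3):
    # Re-seed the sampler's shuffle each epoch; this is what the commit
    # does with self.state.stage_epoch inside _run_epoch.
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # the training step would go here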
