Skip to content

Commit

Permalink
Remove single worker limitation in deterministic mode (IntelLabs#227)
Browse files Browse the repository at this point in the history
Also:
* The single-worker limitation is no longer needed; the underlying issue has
  been fixed in PyTorch since v0.4.0 (pytorch/pytorch#4640)
* compress_classifier.py: If run in evaluation mode (--evaluate), enable
  deterministic mode.
* Call utils.set_deterministic when the data loaders are created if the
  deterministic argument is set (don't assume the user has already called it)
* Disable CUDNN benchmark mode in utils.set_deterministic
  (https://pytorch.org/docs/stable/notes/randomness.html#cudnn)
  • Loading branch information
barrh authored and guyjacob committed Apr 18, 2019
1 parent a3c8d86 commit 8c5de42
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 21 deletions.
11 changes: 10 additions & 1 deletion distiller/apputils/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@
This code will help with the image classification datasets: ImageNet and CIFAR10
"""
import logging
import os
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.sampler import Sampler
import numpy as np

import distiller


msglogger = logging.getLogger()

DATASETS_NAMES = ['imagenet', 'cifar10']


Expand Down Expand Up @@ -170,7 +176,10 @@ def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_
effective_train_size=1., effective_valid_size=1., effective_test_size=1.):
train_dataset, test_dataset = datasets_fn(data_dir)

worker_init_fn = __deterministic_worker_init_fn if deterministic else None
worker_init_fn = None
if deterministic:
distiller.set_deterministic()
worker_init_fn = __deterministic_worker_init_fn

num_train = len(train_dataset)
indices = list(range(num_train))
Expand Down
19 changes: 12 additions & 7 deletions distiller/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,22 @@
This module contains various tensor sparsity/density measurement functions, together
with some random helper functions.
"""
import inspect
import argparse
from collections import OrderedDict
from copy import deepcopy
import logging
import operator
import random

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import random
from copy import deepcopy
import yaml
from collections import OrderedDict
import argparse
import operator

import inspect

msglogger = logging.getLogger()


def model_device(model):
Expand Down Expand Up @@ -584,10 +588,12 @@ def replace_data_parallel(container):


def set_deterministic(seed=0):
    """Seed all random number generators and force deterministic cuDNN behavior.

    Seeds the PyTorch, Python ``random`` and NumPy generators with the same
    value, then configures cuDNN for repeatable runs.

    Args:
        seed: Value used to seed every generator. Defaults to 0, which
            preserves the behavior of calling this function with no arguments.
    """
    msglogger.debug('set_deterministic is called')
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode is switched off together with enabling deterministic mode,
    # per https://pytorch.org/docs/stable/notes/randomness.html#cudnn
    torch.backends.cudnn.benchmark = False


def yaml_ordered_load(stream, Loader=yaml.Loader, object_pairs_hook=OrderedDict):
Expand Down Expand Up @@ -623,7 +629,6 @@ def checker(val_str):
return checker



def filter_kwargs(dict_to_filter, function_to_call):
"""Utility to check which arguments in the passed dictionary exist in a function's signature
Expand Down
15 changes: 7 additions & 8 deletions examples/classifier_compression/compress_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,17 @@ def main():
start_epoch = 0
ending_epoch = args.epochs
perf_scores_history = []

if args.evaluate:
args.deterministic = True
if args.deterministic:
# Experiment reproducibility is sometimes important. Pete Warden expounded about this
# in his blog: https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/
# In Pytorch, support for deterministic execution is still a bit clunky.
if args.workers > 1:
raise ValueError('ERROR: Setting --deterministic requires setting --workers/-j to 0 or 1')
# Use a well-known seed, for repeatability of experiments
distiller.set_deterministic()
distiller.set_deterministic() # Use a well-known seed, for repeatability of experiments
else:
# This issue: https://github.com/pytorch/pytorch/issues/3659
# Implies that cudnn.benchmark should respect cudnn.deterministic, but empirically we see that
# results are not re-produced when benchmark is set. So enabling only if deterministic mode disabled.
# Turn on CUDNN benchmark mode for best performance. This is usually "safe" for image
# classification models, as the input sizes don't change during the run
# See here: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
cudnn.benchmark = True

if args.cpu or not torch.cuda.is_available():
Expand Down
2 changes: 1 addition & 1 deletion examples/classifier_compression/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_parser():
help='Flag to override optimizer if resumed from checkpoint. This will reset epochs count.')

parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
help='evaluate model on test set')
parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(),
help='collect activation statistics on phases: train, valid, and/or test'
' (WARNING: this slows down training)')
Expand Down
8 changes: 4 additions & 4 deletions tests/full_flow_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,16 @@ def collateral_checker(log, *collateral_list):
TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker_args'])

test_configs = [
TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [48.220, 92.930]),
TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.610, 92.080]),
TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate'.
format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
DS_CIFAR, accuracy_checker, [91.640, 99.610]),
DS_CIFAR, accuracy_checker, [91.710, 99.610]),
TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
DS_CIFAR, accuracy_checker, [54.390, 94.280]),
DS_CIFAR, accuracy_checker, [54.590, 94.810]),
TestConfig('-a resnet20_cifar --resume {0} --sense=filter --sense-range 0 0.10 0.05'.
format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
DS_CIFAR, collateral_checker, [('sensitivity.csv', 3165), ('sensitivity.png', 96158)])
DS_CIFAR, collateral_checker, [('sensitivity.csv', 3175), ('sensitivity.png', 96158)])
]


Expand Down

0 comments on commit 8c5de42

Please sign in to comment.