apputils/checkpoint.py: load_checkpoint can be called w/o specifying the model

This is inspired by @barrh's PR IntelLabs#246, but at a "slower integration pace" and w/o changing APIs.

1. create_model() attaches model attributes (arch, dataset, is_parallel) to the models it creates.
2. save_checkpoint() stores the new model attributes in the checkpoint metadata.
3. load_checkpoint() can be invoked with model=None, in which case we attempt to create the model from the stored checkpoint metadata (see the usage sketch below).
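For illustration, a minimal usage sketch of the new behavior (the import path follows the file layout and the checkpoint filename is hypothetical):

    from distiller.apputils import load_checkpoint

    # With a checkpoint written by the new save_checkpoint(), the stored
    # 'arch', 'dataset' and 'is_parallel' metadata let load_checkpoint()
    # build the model internally via create_model(), so model=None is fine.
    model, scheduler, optimizer, start_epoch = load_checkpoint(
        model=None, chkpt_file='mymodel_checkpoint.pth.tar')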
nzmora committed Aug 22, 2019
1 parent 9912435 commit b41c4d2
Showing 3 changed files with 75 additions and 13 deletions.
42 changes: 36 additions & 6 deletions distiller/apputils/checkpoint.py
@@ -62,10 +62,15 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
     filename_best = 'best.pth.tar' if name is None else name + '_best.pth.tar'
     fullpath_best = os.path.join(dir, filename_best)

-    checkpoint = {}
-    checkpoint['epoch'] = epoch
-    checkpoint['arch'] = arch
-    checkpoint['state_dict'] = model.state_dict()
+    checkpoint = {'epoch': epoch, 'state_dict': model.state_dict(), 'arch': arch}
+    try:
+        checkpoint['is_parallel'] = model.is_parallel
+        checkpoint['dataset'] = model.dataset
+        if not arch:
+            checkpoint['arch'] = model.arch
+    except AttributeError:
+        pass

     if optimizer is not None:
         checkpoint['optimizer_state_dict'] = optimizer.state_dict()
         checkpoint['optimizer_type'] = type(optimizer)
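To see what save_checkpoint() now records, a freshly written checkpoint can be inspected with plain torch.load (the filename is hypothetical; the three keys come straight from the attributes that create_model() attaches):

    import torch

    ckpt = torch.load('checkpoints/eraseme_checkpoint.pth.tar',
                      map_location=lambda storage, loc: storage)
    print(ckpt['arch'])         # e.g. 'resnet20_cifar'
    print(ckpt['dataset'])      # e.g. 'cifar10'
    print(ckpt['is_parallel'])  # the `parallel` flag passed to create_model()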
@@ -105,7 +110,10 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
"""Load a pytorch training checkpoint.
Args:
model: the pytorch model to which we will load the parameters
model: the pytorch model to which we will load the parameters. You can
specify model=None if the checkpoint contains enough metadata to infer
the model. The order of the arguments is misleading and clunky, and is
kept this way for backward compatibility.
chkpt_file: the checkpoint file
lean_checkpoint: if set, read into model only 'state_dict' field
optimizer: [deprecated argument]
@@ -159,8 +167,24 @@ def _load_optimizer():
             msglogger.warning('Optimizer could not be loaded from checkpoint.')
             return None

+    def _create_model_from_ckpt():
+        try:
+            return distiller.models.create_model(False, checkpoint['dataset'], checkpoint['arch'],
+                                                 checkpoint['is_parallel'], device_ids=None)
+        except KeyError:
+            return None
+
+    def _sanity_check():
+        try:
+            if model.arch != checkpoint["arch"]:
+                raise ValueError("The model architecture does not match the checkpoint architecture")
+        except (AttributeError, KeyError):
+            # One of the values is missing, so we can't perform the comparison
+            pass
+
     if not os.path.isfile(chkpt_file):
         raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
+    assert optimizer is None, "argument optimizer is deprecated and must be set to None"

     msglogger.info("=> loading checkpoint %s", chkpt_file)
     checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
@@ -171,9 +195,14 @@ def _load_optimizer():
     if 'state_dict' not in checkpoint:
         raise ValueError("Checkpoint must contain the model parameters under the key 'state_dict'")

+    if not model:
+        model = _create_model_from_ckpt()
+        if not model:
+            raise ValueError("You didn't provide a model, and the checkpoint doesn't contain "
+                             "enough information to create one")

     checkpoint_epoch = checkpoint.get('epoch', None)
     start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0

     compression_scheduler = None
     normalize_dataparallel_keys = False
     if 'compression_sched' in checkpoint:
@@ -217,4 +246,5 @@ def _load_optimizer():
     optimizer = _load_optimizer()
     msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
                                                                    e=checkpoint_epoch))
+    _sanity_check()
     return model, compression_scheduler, optimizer, start_epoch
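Putting the pieces together, a hedged sketch of how callers are affected (filenames and arch are hypothetical): an old checkpoint that lacks the metadata keys still requires an explicit model, while a checkpoint saved by the new save_checkpoint() does not:

    from distiller.apputils import load_checkpoint
    from distiller.models import create_model

    try:
        # Old-style checkpoint: _create_model_from_ckpt() returns None and
        # load_checkpoint() raises ValueError.
        model, _, _, _ = load_checkpoint(model=None, chkpt_file='old_style.pth.tar')
    except ValueError:
        # Fall back to constructing the model explicitly, as before this commit.
        model = create_model(False, 'cifar10', 'resnet20_cifar')
        model, _, _, _ = load_checkpoint(model, 'old_style.pth.tar')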
13 changes: 9 additions & 4 deletions distiller/models/__init__.py
@@ -111,14 +111,19 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
                                                      arch, dataset))
     if torch.cuda.is_available() and device_ids != -1:
         device = 'cuda'
-        if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel:
-            model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
-        elif parallel:
-            model = torch.nn.DataParallel(model, device_ids=device_ids)
+        if parallel:
+            if arch.startswith('alexnet') or arch.startswith('vgg'):
+                model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
+            else:
+                model = torch.nn.DataParallel(model, device_ids=device_ids)
     else:
         device = 'cpu'

+    # Cache some attributes which describe the model
     _set_model_input_shape_attr(model, arch, dataset, pretrained, cadene)
+    model.arch = arch
+    model.dataset = dataset
+    model.is_parallel = parallel
     return model.to(device)
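As a quick check of the cached attributes (assumed import path; the assertions mirror those in the new test below):

    from distiller.models import create_model

    model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
    assert model.arch == 'resnet20_cifar'
    assert model.dataset == 'cifar10'
    assert model.is_parallel is False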


33 changes: 30 additions & 3 deletions tests/test_infra.py
Expand Up @@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import logging
-import tempfile

 import torch
 import pytest
 import distiller
@@ -182,8 +182,8 @@ def test_load_gpu_model_on_cpu_with_thinning():
     distiller.remove_filters(gpu_model, zeros_mask_dict, 'resnet20_cifar', 'cifar10', optimizer=None)
     assert hasattr(gpu_model, 'thinning_recipes')
     scheduler = distiller.CompressionScheduler(gpu_model)
-    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None,
-                    dir='checkpoints')
+    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model,
+                    scheduler=scheduler, optimizer=None, dir='checkpoints')

     CPU_DEVICE_ID = -1
     cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
@@ -269,3 +269,30 @@ def check_shape_device(t, exp_shape, exp_device):
     check_shape_device(t[0], shape[0], expected_device)
     check_shape_device(t[1][0], shape[1][0], expected_device)
     check_shape_device(t[1][1], shape[1][1], expected_device)
+
+
+def test_load_checkpoint_without_model():
+    checkpoint_filename = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar'
+    # Load a checkpoint w/o specifying the model: this should fail because the loaded
+    # checkpoint is old and does not have the required metadata to create a model.
+    with pytest.raises(ValueError):
+        load_checkpoint(model=None, chkpt_file=checkpoint_filename)
+
+    for model_device in (None, 'cuda', 'cpu'):
+        # Now we create a new model, save a checkpoint, and load it w/o specifying the model.
+        # This should succeed because the checkpoint has enough metadata to create the model.
+        model = create_model(False, 'cifar10', 'resnet20_cifar', 0)
+        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, checkpoint_filename)
+        save_checkpoint(epoch=0, arch='resnet20_cifar', model=model, name='eraseme',
+                        scheduler=compression_scheduler, optimizer=None, dir='checkpoints')
+        temp_checkpoint = os.path.join("checkpoints", "eraseme_checkpoint.pth.tar")
+        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model=None,
+                                                                               chkpt_file=temp_checkpoint,
+                                                                               model_device=model_device)
+        assert compression_scheduler is not None
+        assert optimizer is None
+        assert start_epoch == 1
+        assert model
+        assert model.arch == "resnet20_cifar"
+        assert model.dataset == "cifar10"
+        os.remove(temp_checkpoint)
