checkpoint.py: non-functional code refactoring
Rearranged the code for easier reading and maintenance
nzmora committed Aug 22, 2019
1 parent bdafebe commit 9912435
Showing 1 changed file with 61 additions and 54 deletions.
distiller/apputils/checkpoint.py: 115 changes (61 additions & 54 deletions)
@@ -77,7 +77,6 @@ def save_checkpoint(epoch, arch, model, optimizer=None, scheduler=None,
checkpoint['quantizer_metadata'] = model.quantizer_metadata

checkpoint['extras'] = extras
-
torch.save(checkpoint, fullpath)
if is_best:
shutil.copyfile(fullpath, fullpath_best)
@@ -101,8 +100,8 @@ def inspect_val(val):
return tabulate(contents, headers=["Key", "Type", "Value"], tablefmt="psql")


-def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
-                    lean_checkpoint=False, strict=False):
+def load_checkpoint(model, chkpt_file, optimizer=None,
+                    model_device=None, lean_checkpoint=False, strict=False):
"""Load a pytorch training checkpoint.
Args:
@@ -114,6 +113,52 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
This should be set to either 'cpu' or 'cuda'.
:returns: updated model, compression_scheduler, optimizer, start_epoch
"""
+    def _load_compression_scheduler():
+        normalize_keys = False
+        try:
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_keys)
+        except KeyError as e:
+            # A very common source of this KeyError is loading a GPU model on the CPU.
+            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
+            normalize_keys = True
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_keys)
+        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
+            checkpoint_epoch))
+        return normalize_keys
+
+    def _load_and_execute_thinning_recipes():
+        msglogger.info("Loaded a thinning recipe from the checkpoint")
+        # Cache the recipes in case we need them later
+        model.thinning_recipes = checkpoint['thinning_recipes']
+        if normalize_dataparallel_keys:
+            model.thinning_recipes = [distiller.get_normalized_recipe(recipe)
+                                      for recipe in model.thinning_recipes]
+        distiller.execute_thinning_recipes_list(model,
+                                                compression_scheduler.zeros_mask_dict,
+                                                model.thinning_recipes)
+
+    def _load_optimizer():
+        """Initialize optimizer with model parameters and load src_state_dict"""
+        try:
+            cls, src_state_dict = checkpoint['optimizer_type'], checkpoint['optimizer_state_dict']
+            # Initialize the dest_optimizer with a dummy learning rate,
+            # this is required to support SGD.__init__()
+            dest_optimizer = cls(model.parameters(), lr=1)
+            dest_optimizer.load_state_dict(src_state_dict)
+            msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(
+                type=type(dest_optimizer)))
+            optimizer_param_groups = dest_optimizer.state_dict()['param_groups']
+            msglogger.info('Optimizer Args: {}'.format(
+                dict((k, v) for k, v in optimizer_param_groups[0].items()
+                     if k != 'params')))
+            return dest_optimizer
+        except KeyError:
+            # Older checkpoints do not support optimizer loading: they either had an 'optimizer'
+            # field (a different name) which was not used during the load, or they didn't even
+            # checkpoint the optimizer.
+            msglogger.warning('Optimizer could not be loaded from checkpoint.')
+            return None
+
if not os.path.isfile(chkpt_file):
raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)

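The _load_compression_scheduler helper above retries loading with key normalization because a checkpoint saved from a torch.nn.DataParallel-wrapped model prefixes every state key with 'module.', which no longer matches module names when the model runs unwrapped (e.g. on the CPU). A minimal, runnable sketch of that renaming, assuming a normalizer along the lines of distiller's normalize_module_name:

    def normalize_module_name(layer_name):
        # Drop the 'module.' component that DataParallel prepends,
        # e.g. 'module.conv1.weight' -> 'conv1.weight'.
        parts = layer_name.split('.')
        if 'module' in parts:
            parts.remove('module')
        return '.'.join(parts)

    # Hypothetical keys from a DataParallel-saved state dict:
    state_dict = {'module.conv1.weight': 0, 'module.fc.bias': 1}
    state_dict = {normalize_module_name(k): v for k, v in state_dict.items()}
    print(sorted(state_dict))  # ['conv1.weight', 'fc.bias']
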
@@ -133,30 +178,15 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
normalize_dataparallel_keys = False
if 'compression_sched' in checkpoint:
compression_scheduler = distiller.CompressionScheduler(model)
-        try:
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
-        except KeyError as e:
-            # A very common source of this KeyError is loading a GPU model on the CPU.
-            # We rename all of the DataParallel keys because DataParallel does not execute on the CPU.
-            normalize_dataparallel_keys = True
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], normalize_dataparallel_keys)
-        msglogger.info("Loaded compression schedule from checkpoint (epoch {})".format(
-            checkpoint_epoch))
+        normalize_dataparallel_keys = _load_compression_scheduler()
else:
msglogger.info("Warning: compression schedule data does not exist in the checkpoint")

if 'thinning_recipes' in checkpoint:
if 'compression_sched' not in checkpoint:
msglogger.warning("Found thinning_recipes key, but missing mandatory key compression_sched")
if not compression_scheduler:
msglogger.warning("Found thinning_recipes key, but missing key compression_scheduler")
compression_scheduler = distiller.CompressionScheduler(model)
msglogger.info("Loaded a thinning recipe from the checkpoint")
# Cache the recipes in case we need them later
model.thinning_recipes = checkpoint['thinning_recipes']
if normalize_dataparallel_keys:
model.thinning_recipes = [distiller.get_normalized_recipe(recipe) for recipe in model.thinning_recipes]
distiller.execute_thinning_recipes_list(model,
compression_scheduler.zeros_mask_dict,
model.thinning_recipes)
_load_and_execute_thinning_recipes()

if 'quantizer_metadata' in checkpoint:
msglogger.info('Loaded quantizer metadata from the checkpoint')
@@ -165,49 +195,26 @@ def load_checkpoint(model, chkpt_file, optimizer=None, model_device=None, *,
quantizer.prepare_model(qmd['dummy_input'])

if normalize_dataparallel_keys:
checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
if anomalous_keys:
# This is pytorch 1.1+
missing_keys, unexpected_keys = anomalous_keys
if unexpected_keys:
msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" % (chkpt_file, len(unexpected_keys)))
msglogger.warning("Warning: the loaded checkpoint (%s) contains %d unexpected state keys" %
(chkpt_file, len(unexpected_keys)))
if missing_keys:
raise ValueError("The loaded checkpoint (%s) is missing %d state keys" % (chkpt_file, len(missing_keys)))

raise ValueError("The loaded checkpoint (%s) is missing %d state keys" %
(chkpt_file, len(missing_keys)))

if model_device is not None:
model.to(model_device)

if lean_checkpoint:
msglogger.info("=> loaded 'state_dict' from checkpoint '{}'".format(str(chkpt_file)))
return (model, None, None, 0)

-    def _load_optimizer(cls, src_state_dict, model):
-        """Initiate optimizer with model parameters and load src_state_dict"""
-        # initiate the dest_optimizer with a dummy learning rate,
-        # this is required to support SGD.__init__()
-        dest_optimizer = cls(model.parameters(), lr=1)
-        dest_optimizer.load_state_dict(src_state_dict)
-        return dest_optimizer
-
-    try:
-        optimizer = _load_optimizer(checkpoint['optimizer_type'],
-                                    checkpoint['optimizer_state_dict'], model)
-    except KeyError:
-        # Older checkpoints do support optimizer loading: They either had an 'optimizer' field
-        # (different name) which was not used during the load, or they didn't even checkpoint
-        # the optimizer.
-        optimizer = None
-
-    if optimizer is not None:
-        msglogger.info('Optimizer of type {type} was loaded from checkpoint'.format(
-            type=type(optimizer)))
-        msglogger.info('Optimizer Args: {}'.format(
-            dict((k,v) for k,v in optimizer.state_dict()['param_groups'][0].items()
-                 if k != 'params')))
-    else:
-        msglogger.warning('Optimizer could not be loaded from checkpoint.')
-        return model, None, None, 0
-
+    optimizer = _load_optimizer()
msglogger.info("=> loaded checkpoint '{f}' (epoch {e})".format(f=str(chkpt_file),
e=checkpoint_epoch))
-    return (model, compression_scheduler, optimizer, start_epoch)
+    return model, compression_scheduler, optimizer, start_epoch
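
A side note on _load_optimizer: torch.optim.SGD.__init__ raises if no learning rate is given, so the helper constructs the optimizer with a placeholder lr=1 and lets load_state_dict restore the real hyperparameters. A self-contained sketch of the pattern (the model here is a stand-in, not from distiller):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in model
    src_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    saved_state = src_optimizer.state_dict()  # what the checkpoint stores

    # Rebuild from the saved class and state: the dummy lr=1 only satisfies
    # __init__; load_state_dict then restores the real lr and momentum.
    cls = type(src_optimizer)
    dest_optimizer = cls(model.parameters(), lr=1)
    dest_optimizer.load_state_dict(saved_state)
    print(dest_optimizer.state_dict()['param_groups'][0]['lr'])  # 0.01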

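For context, a hypothetical call site for the refactored function. Since the keyword-only * marker was dropped from the signature, model_device, lean_checkpoint, and strict can now also be passed positionally; the checkpoint path and surrounding code below are illustrative, assuming distiller.apputils re-exports load_checkpoint:

    from distiller.apputils import load_checkpoint

    # 'net' is assumed to match the architecture stored in the checkpoint.
    net, compression_scheduler, optimizer, start_epoch = load_checkpoint(
        net, 'logs/checkpoint.pth.tar', model_device='cuda')

    # Lean loading returns only the populated model:
    # scheduler and optimizer come back as None, epoch as 0.
    net, _, _, _ = load_checkpoint(net, 'logs/checkpoint.pth.tar',
                                   lean_checkpoint=True)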