Fix non 1:1 mapping between model w. ModuleList and SummaryGraph (IntelLabs#328)

The PyTorch trace mechanism doesn't "see" torch.nn.ModuleList modules
(since they don't have a forward function). As a result, the mapping
from module names at the Python model definition level to the
scope-names at the trace level is not 1:1. This makes it impossible for
us to map back from SummaryGraph ops to their respective nn.Modules,
which is required for flows like BatchNorm folding and stats fusion in
post-training quantization.
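
As an illustration, consider a hypothetical toy model (the sub-module type and
sizes are arbitrary) with two ModuleLists holding the same kind of sub-modules,
similar to DistillerLSTM's 'cells' / 'cells_reverse':

    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.cells = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
            self.cells_reverse = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])

        def forward(self, x):
            # The lists are only containers - each cell is invoked directly
            for front, back in zip(self.cells, self.cells_reverse):
                x = front(x) + back(x)
            return x

At the Python level the sub-modules have unique names ('cells.0', 'cells.1',
'cells_reverse.0', 'cells_reverse.1'), but since neither list contributes a
level to the trace scope-names, ops recorded in the trace cannot be
unambiguously mapped back to these names.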

In IntelLabs#313 we handled this issue specifically in DistillerLSTM, but it
makes much more sense to have a generic, automatic solution that doesn't
require the user to modify the model. This commit provides such a solution.
    
* Implemented DistillerModuleList, a replacement for nn.ModuleList
  which results in full and unique scope-names
* See documentation for this class in summary_graph.py for extensive
  details on the issue and solution
* When generating a SummaryGraph, the model is scanned and all instances
  of torch.nn.ModuleList are replaced with DistillerModuleList (see the
  sketch after this list)
* Add tests for new functionality
* Partially revert changes made to DistillerLSTM in commit 43548de:
  Keep the refactored _create_cells_list function, but have it create
  a standard torch.nn.ModuleList (since we're handling the ModuleList
  issue automatically now, there's no need to confuse users with ad-hoc
  list implementations)
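
A minimal sketch of the user-facing behavior (using the hypothetical ToyModel
above; nothing in the model itself needs to change):

    import torch
    from distiller.summary_graph import SummaryGraph

    # The ModuleList -> DistillerModuleList conversion happens internally on a
    # copy of the model, and op names are mapped back to the original
    # 'cells.0' / 'cells_reverse.0' style names.
    sg = SummaryGraph(ToyModel(), torch.randn(1, 8))
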
guyjacob authored Jul 22, 2019
1 parent 4ec96d9 commit b614330
Showing 6 changed files with 388 additions and 56 deletions.
27 changes: 8 additions & 19 deletions distiller/modules/rnn.py
@@ -201,45 +201,34 @@ def __init__(self, input_size, hidden_size, num_layers, bias=True, batch_first=F
                # # Process each timestep at the entire layers chain -
                # # each timestep is forwarded through `front` and `back` chains independently,
                # # similarily to a unidirectional LSTM.
-                # self.cells = self._create_cells_list('cell', 1)
-                # self.cells_reverse = self._create_cells_list('cell_reverse', 2)
+                # self.cells = self._create_cells_list(1)
+                # self.cells_reverse = self._create_cells_list(2)
                # self.forward_fn = self.process_layer_wise
                # self.layer_chain_fn = self._layer_chain_bidirectional_type1

            elif bidirectional_type == 2:
                # Process the entire sequence at each layer consecutively -
                # the output of one layer is the sequence processed through the `front` and `back` cells
                # and the input to the next layers are both `output_front` and `output_back`.
-                self.cells = self._create_cells_list('cell', 2)
-                self.cells_reverse = self._create_cells_list('cell_reverse', 2)
+                self.cells = self._create_cells_list(2)
+                self.cells_reverse = self._create_cells_list(2)
                self.forward_fn = self._bidirectional_type2_forward

            else:
                raise ValueError("The only allowed types are [1, 2].")
        else:
-            self.cells = self._create_cells_list('cell')
+            self.cells = self._create_cells_list()
            self.forward_fn = self.process_layer_wise
            self.layer_chain_fn = self._layer_chain_unidirectional

        self.dropout = nn.Dropout(dropout)
        self.dropout_factor = dropout

-    def _create_cells_list(self, name, hidden_size_scale=1):
-        # We don't use a ModuleList, because they don't show up properly as scope names when creating a trace.
-        # That makes it impossible to map back from the trace to the actual module, which in turn means that
-        # mechanisms that rely on understanding modules connectivity won't work (such as fusions in post-training
-        # quantization).
-        #
-        # So, we register each cell manually and just store them in a vanilla list
-
+    def _create_cells_list(self, hidden_size_scale=1):
        # We always have the first layer
-        c = DistillerLSTMCell(self.input_size, self.hidden_size, self.bias)
-        setattr(self, name + '_0', c)
-        cells = [c]
+        cells = nn.ModuleList([DistillerLSTMCell(self.input_size, self.hidden_size, self.bias)])
        for i in range(1, self.num_layers):
-            c = DistillerLSTMCell(hidden_size_scale * self.hidden_size, self.hidden_size, self.bias)
-            setattr(self, '{}_{}'.format(name, i), c)
-            cells.append(c)
+            cells.append(DistillerLSTMCell(hidden_size_scale * self.hidden_size, self.hidden_size, self.bias))
        return cells

    def forward(self, x, h=None):
177 changes: 177 additions & 0 deletions distiller/summary_graph.py
@@ -19,9 +19,11 @@
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.jit as jit
import logging
from collections import OrderedDict, defaultdict
from collections.abc import MutableSequence, Iterable
msglogger = logging.getLogger()


@@ -71,6 +73,11 @@ class SummaryGraph(object):
def __init__(self, model, dummy_input, apply_scope_name_workarounds=True):
self._src_model = model
model_clone = distiller.make_non_parallel_copy(model)

# Switch all instances of torch.nn.ModuleList in the model to our DistillerModuleList
# See documentation of _DistillerModuleList class for details on why this is done
model_clone, converted_module_names_map = _to_distiller_modulelist(model_clone)

with torch.onnx.set_training(model_clone, False):

device = distiller.model_device(model_clone)
@@ -142,6 +149,10 @@ def __init__(self, model, dummy_input, apply_scope_name_workarounds=True):

# Convert the graph node's scope name to a PyTorch module name
module_name = onnx_name_2_pytorch_name(new_op['orig-name'])

# Get name from before conversion to DistillerModuleList
module_name = converted_module_names_map[module_name]

if len(module_name) == 0:
# Special case where the module name is an empty string - this happens
# when the op is called from the "top-level" of the model
@@ -561,3 +572,169 @@ def __eq__(self, other):
return self.op_meta == other.op_meta and \
self.predecessors == other.predecessors and \
self.successors == other.successors


class _DistillerModuleList(object):
r"""A almost-drop-in replacement for torch.nn.ModuleList that results in full and unique scope-names when traced
So why do we need this?
Some flows in Distiller, such as modules fusion and "net-aware" quantization in PostTrainLinearQuantizer, rely
on the ability to infer the connectivity within the model, at the Python API level. This is done using
SummaryGraph, which internally uses PyTorch's trace capabilities. When tracing, each operation
executed creates a node in the trace, which has a "scope-name". Distiller then uses the "scope-name" to do a
reverse mapping - map from the trace node back to the actual nn.Module defined in the model code.
These "scope-names" are generated by tracking the ".forward()" calls of modules. However, The torch.nn.ModuleList
class itself doesn't have its own forward method. That makes perfect sense - it is only intended to be used as a
container of modules which the user accesses explicitly.
Unfortunately, this means that if an operation is part of a ModuleList, the name of the ModuleList instance
does not appear in the "scope-name". This makes it impossible for us to do the reverse mapping mentioned
above.
From here on, we refer to the module which contains the DistillerModuleList instance as the "parent module".
Similarities to torch.nn.ModuleList:
* A DistillerModuleList can be indexed like a regular Python list, but the modules it contains are properly
registered and will be visible to all torch.nn.Module methods.
    * The DistillerModuleList instance is registered as an attribute of the "parent module"
* This means that in terms of accessing the modules and invoking them, DistillerModuleList behaves exactly the
same as torch.nn.ModuleList. See the example below.
Differences vs. torch.nn.ModuleList:
* DistillerModuleList is NOT a sub-class of torch.nn.Module
* This means that the modules in the list are NOT sub-modules of the list itself. They are registered as
sub-modules of the "parent module". That is - the contents of a DistillerModuleList are "flattened" within the
"parent module".
* In addition, we can't use the '.' character to denote the "nesting" of a module within the list. We use '_'.
* All of this means that calls to functions like state_dict() / named_modules() / named_children() / etc. on the
"parent_module" return different results when this class is used compared to torch.nn.ModuleList.
At the moment we don't see a usage for this class "in the wild", outside of SummaryGraph generation.
In the context of SummaryGraph, we're going to take a pre-created model and replace any torch.nn.ModuleList
instances with DistillerModuleLists. Once that is done, during model execution we expect that lists are being
used as read-only (no modules are added to/removed from the list). We're not supporting loading state_dict "across"
converted models.
This means that:
* We implement only a subset of the standard API of a Python sequence (see collections.abc.MutableSequence):
        'append()', 'extend()', '__len__()' and '__getitem__()'
These are the only ones required to perform the conversion for an already created model.
* We're not implementing:
'insert()', '__setitem__()' and '__delitem__()'.
If we see in the future that our assumptions break, we'll add the necessary APIs.
For all the reasons mentioned above, and to avoid unnecessary confusion for users, we're keeping this class
internal to summary_graph for now.
Args:
name (string): The base name to be used when registering modules added to the list
parent_module (torch.nn.Module): The module to which the modules added to the list will be registered.
NOTE: This is expected to be the module containing the list, but we can't enforce this.
modules (iterable, optional): An iterable of modules to initialize the list with
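    Example (an illustrative sketch; the module and layer sizes are arbitrary)::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.linears = _DistillerModuleList('linears', self, [nn.Linear(10, 10) for _ in range(3)])

            def forward(self, x):
                # Iterates and indexes like torch.nn.ModuleList, but the Linear modules are
                # registered on MyModule as 'linears_0', 'linears_1', 'linears_2'
                for layer in self.linears:
                    x = layer(x)
                return x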
"""
def __init__(self, name, parent_module, modules=None):
self.name = name
if not isinstance(parent_module, nn.Module):
raise TypeError('parent_module must be an instance of torch.nn.Module')
self.parent_module = parent_module
self._modules = []
if modules is not None:
self.extend(modules)

def _name_for_idx(self, idx):
return self.name + '_' + str(idx)

def _verify_on_insertion(self, module, idx):
if isinstance(module, nn.ModuleList):
module = _DistillerModuleList(self._name_for_idx(idx), self.parent_module, module)
if isinstance(module, _DistillerModuleList):
if module.parent_module != self.parent_module:
raise ValueError("When nesting one DistillerModuleList within another, both must have the same "
"'parent_module'")
return module

def __getitem__(self, idx):
return self._modules[idx]

def __len__(self):
return len(self._modules)

def append(self, module):
module = self._verify_on_insertion(module, len(self))
if not isinstance(module, _DistillerModuleList):
self.parent_module.add_module(self._name_for_idx(len(self)), module)
self._modules.append(module)

def extend(self, modules):
if not isinstance(modules, Iterable):
raise TypeError('DistillerModuleList.extend must be called with an iterable, but got ' +
modules.__class__.__name__)
for module in modules:
self.append(module)

def named_modules(self, memo=None, prefix=''):
if memo is None:
memo = set()
if self not in memo:
memo.add(self)
# yield prefix, self
for idx, module in enumerate(self._modules):
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + str(idx)
for m in module.named_modules(memo, submodule_prefix):
yield m

def modules(self):
for _, module in self.named_modules():
yield module

def __repr__(self):
# A simplified version of torch.nn.Module.__repr__
from torch.nn.modules.module import _addindent

child_lines = []
for idx, module in enumerate(self._modules):
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + str(idx) + '): ' + mod_str)

main_str = self.__class__.__name__ + '('
if child_lines:
main_str += '\n ' + '\n '.join(child_lines) + '\n'
main_str += ')'
return main_str


def _to_distiller_modulelist(model):
"""Replaces all instances of torch.nn.ModuleList in a model with DistillerModuleList instances
Args:
model (torch.nn.Module): Model to convert
"""
def convert_container(container):
named_children = OrderedDict(container.named_children())
# To maintain a similar order of registered modules compared to the original container, we unregister
# all modules and then register them again
for n, _ in named_children.items():
delattr(container, n)
for name, child in named_children.items():
if isinstance(child, nn.ModuleList):
child = _DistillerModuleList(name, container, child)
to_check = child.modules()
else:
to_check = [child]
setattr(container, name, child)
for m in to_check:
if isinstance(m, _DistillerModuleList):
continue
if distiller.has_children(m):
convert_container(m)
return container

named_modules_orig = OrderedDict([(n, m) for n, m in model.named_modules() if not isinstance(m, nn.ModuleList)])
model = convert_container(model)
named_modules_dmlist = OrderedDict(model.named_modules())
converted_module_names_map = OrderedDict(zip(named_modules_dmlist.keys(), named_modules_orig.keys()))

return model, converted_module_names_map
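
A minimal sketch of what this conversion produces (hypothetical toy model with
a single nn.ModuleList; the names shown follow the registration scheme above):

    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super(Toy, self).__init__()
            self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    model, names_map = _to_distiller_modulelist(Toy())
    # The Linear modules are now registered directly on the parent, with '_' instead of '.':
    #   [name for name, _ in model.named_modules()] == ['', 'layers_0', 'layers_1']
    # names_map maps the flattened names back to the originals:
    #   names_map == OrderedDict([('', ''), ('layers_0', 'layers.0'), ('layers_1', 'layers.1')])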