Commit

Merge pull request #410 from yoshitomo-matsubara/dev
Simplify (D)DP wrapper init
yoshitomo-matsubara authored Oct 31, 2023
2 parents cb8ad15 + 14a9ec7 commit e768a70
Showing 4 changed files with 22 additions and 13 deletions.
6 changes: 2 additions & 4 deletions torchdistill/core/distillation.py
@@ -205,16 +205,14 @@ def setup(self, train_config):
         freeze_module_params(self.student_model)

         # Wrap models if necessary
-        teacher_unused_parameters = teacher_config.get('find_unused_parameters', self.teacher_any_frozen)
         teacher_any_updatable = len(get_updatable_param_names(self.teacher_model)) > 0
         self.teacher_model =\
             wrap_model(self.teacher_model, teacher_config, self.device, self.device_ids, self.distributed,
-                       teacher_unused_parameters, teacher_any_updatable)
-        student_unused_parameters = student_config.get('find_unused_parameters', self.student_any_frozen)
+                       self.teacher_any_frozen, teacher_any_updatable)
         student_any_updatable = len(get_updatable_param_names(self.student_model)) > 0
         self.student_model =\
             wrap_model(self.student_model, student_config, self.device, self.device_ids, self.distributed,
-                       student_unused_parameters, student_any_updatable)
+                       self.student_any_frozen, student_any_updatable)

         # Set up optimizer and scheduler
         optim_config = train_config.get('optimizer', dict())
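
With the per-model find_unused_parameters lookup removed from the call sites, an explicit override now travels through the wrapper entry that wrap_model parses (see torchdistill/core/util.py below), while the default is simply the frozen-parameter flag passed above. A minimal sketch of what such a teacher entry could look like; the values are illustrative, not project defaults:

# Hypothetical teacher_config fragment: an explicit override lives under wrapper kwargs,
# while the fallback is the self.teacher_any_frozen flag passed by the setup code above.
teacher_config = {
    'wrapper': {
        'key': 'DistributedDataParallel',
        'kwargs': {'find_unused_parameters': True}
    },
    # other teacher entries unchanged
}
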
3 changes: 1 addition & 2 deletions torchdistill/core/training.py
@@ -164,10 +164,9 @@ def setup(self, train_config):

         # Wrap models if necessary
         any_updatable = len(get_updatable_param_names(self.model)) > 0
-        model_unused_parameters = model_config.get('find_unused_parameters', self.model_any_frozen)
         self.model =\
             wrap_model(self.model, model_config, self.device, self.device_ids, self.distributed,
-                       model_unused_parameters, any_updatable)
+                       self.model_any_frozen, any_updatable)

         # Set up optimizer and scheduler
         optim_config = train_config.get('optimizer', dict())
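
The same simplification applies to this single-model path. A minimal sketch of the resulting precedence, with hypothetical values, mirroring the wrap_model logic shown in torchdistill/core/util.py below: an explicit wrapper kwarg wins, otherwise the frozen-parameter flag passed by the call site is used.

# Hypothetical values; equivalent to wrap_model's "if 'find_unused_parameters' not in wrapper_kwargs" check.
model_any_frozen = True                              # default now passed straight from the call site
wrapper_kwargs = {'find_unused_parameters': False}   # optional per-wrapper override from the config
effective = wrapper_kwargs.get('find_unused_parameters', model_any_frozen)
print(effective)  # False: the explicit kwarg takes precedence over the frozen-parameter default
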
21 changes: 16 additions & 5 deletions torchdistill/core/util.py
@@ -89,12 +89,23 @@ def wrap_model(model, model_config, device, device_ids=None, distributed=False,
     :rtype: nn.Module
     """
     wrapper = model_config.get('wrapper', None) if model_config is not None else None
+    wrapper_kwargs = dict()
+    if isinstance(wrapper, dict):
+        wrapper_key = wrapper.get('key', None)
+        wrapper_kwargs = wrapper.get('kwargs', wrapper_kwargs)
+    else:
+        wrapper_key = wrapper
+
+    wrapper_kwargs['device_ids'] = device_ids
     model.to(device)
-    if wrapper is not None and device.type.startswith('cuda') and not check_if_wrapped(model):
-        if wrapper == 'DistributedDataParallel' and distributed and any_updatable:
-            model = DistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=find_unused_parameters)
-        elif wrapper in {'DataParallel', 'DistributedDataParallel'}:
-            model = DataParallel(model, device_ids=device_ids)
+    if wrapper_key is not None and device.type.startswith('cuda') and not check_if_wrapped(model):
+        if wrapper_key == 'DistributedDataParallel' and distributed and any_updatable:
+            if 'find_unused_parameters' not in wrapper_kwargs:
+                wrapper_kwargs['find_unused_parameters'] = find_unused_parameters
+
+            model = DistributedDataParallel(model, **wrapper_kwargs)
+        elif wrapper_key in {'DataParallel', 'DistributedDataParallel'}:
+            model = DataParallel(model, **wrapper_kwargs)
     return model

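
For reference, a minimal sketch of calling the updated wrap_model with the new dict form of the wrapper entry. The config values and extra DDP kwargs are illustrative, and the DistributedDataParallel branch assumes CUDA is available and torch.distributed has already been initialized:

import torch
from torch import nn
from torchdistill.core.util import wrap_model

# A bare string still works, e.g. {'wrapper': 'DistributedDataParallel'}, in which case
# only device_ids is passed to the wrapper. The dict form carries explicit wrapper kwargs.
dict_config = {'wrapper': {'key': 'DistributedDataParallel',
                           'kwargs': {'find_unused_parameters': True, 'broadcast_buffers': False}}}

model = nn.Linear(10, 2)
device = torch.device('cuda')

# wrap_model adds device_ids to the kwargs itself; find_unused_parameters below is only the
# fallback used when the wrapper kwargs do not set it explicitly.
model = wrap_model(model, dict_config, device, device_ids=[0], distributed=True,
                   find_unused_parameters=False, any_updatable=True)
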
5 changes: 3 additions & 2 deletions torchdistill/models/util.py
@@ -15,7 +15,7 @@
 logger = def_logger.getChild(__name__)


-def wrap_if_distributed(module, device, device_ids, distributed, find_unused_parameters=None):
+def wrap_if_distributed(module, device, device_ids, distributed, find_unused_parameters=None, **kwargs):
     """
     Wraps ``module`` with DistributedDataParallel if ``distributed`` = True and ``module`` has any updatable parameters.
@@ -37,7 +37,8 @@ def wrap_if_distributed(module, device, device_ids, distributed, find_unused_par
         any_frozen = len(get_frozen_param_names(module)) > 0
         if find_unused_parameters is None:
             find_unused_parameters = any_frozen
-        return DistributedDataParallel(module, device_ids=device_ids, find_unused_parameters=find_unused_parameters)
+        return DistributedDataParallel(module, device_ids=device_ids, find_unused_parameters=find_unused_parameters,
+                                       **kwargs)
     return module


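Similarly, wrap_if_distributed now forwards any extra keyword arguments straight to DistributedDataParallel. A minimal sketch, assuming an initialized process group and an available CUDA device; broadcast_buffers is just one example of a standard DDP kwarg:

import torch
from torch import nn
from torchdistill.models.util import wrap_if_distributed

module = nn.Linear(10, 2)
# find_unused_parameters is left as None, so it falls back to whether the module has frozen
# parameters; additional kwargs (here broadcast_buffers) pass through to DistributedDataParallel.
module = wrap_if_distributed(module, torch.device('cuda'), [0], distributed=True,
                             broadcast_buffers=False)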
