diff --git a/catalyst/__version__.py b/catalyst/__version__.py
index 8569b7a796..6c5f370d29 100644
--- a/catalyst/__version__.py
+++ b/catalyst/__version__.py
@@ -1 +1 @@
-__version__ = "21.09rc1"
+__version__ = "21.09"
diff --git a/catalyst/engines/apex.py b/catalyst/engines/apex.py
index 8dd2d628c4..43e58aaa67 100644
--- a/catalyst/engines/apex.py
+++ b/catalyst/engines/apex.py
@@ -370,7 +370,7 @@ class DistributedDataParallelAPEXEngine(DistributedDataParallelEngine):
         address: address to use for backend.
         port: port to use for backend.
         sync_bn: boolean flag for batchnorm synchonization during disributed training.
-            if True, applies PyTorch `convert_sync_batchnorm`_ to the model for native torch
+            if True, applies Apex `convert_syncbn_model`_ to the model for native torch
             distributed only. Default, False.
         ddp_kwargs: parameters for `apex.parallel.DistributedDataParallel`.
             More info here:
@@ -439,9 +439,8 @@ def get_engine(self):
         stages:
             ...

-    .. _convert_sync_batchnorm:
-        https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
-        torch.nn.SyncBatchNorm.convert_sync_batchnorm
+    .. _`convert_syncbn_model`:
+        https://nvidia.github.io/apex/parallel.html#apex.parallel.convert_syncbn_model
     """

     def __init__(
@@ -501,7 +500,7 @@ def init_components(
         model = model_fn()
         model = self.sync_device(model)
         if self._sync_bn:
-            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = apex.parallel.convert_syncbn_model(model)

         criterion = criterion_fn()
         criterion = self.sync_device(criterion)
diff --git a/catalyst/engines/fairscale.py b/catalyst/engines/fairscale.py
index d03f60e95c..b690125396 100644
--- a/catalyst/engines/fairscale.py
+++ b/catalyst/engines/fairscale.py
@@ -340,6 +340,7 @@ def __init__(
         self,
         address: str = None,
         port: Union[str, int] = None,
+        sync_bn: bool = False,
         ddp_kwargs: Dict[str, Any] = None,
         process_group_kwargs: Dict[str, Any] = None,
         scaler_kwargs: Dict[str, Any] = None,
@@ -348,6 +349,7 @@
         super().__init__(
             address=address,
             port=port,
+            sync_bn=sync_bn,
             ddp_kwargs=ddp_kwargs,
             process_group_kwargs=process_group_kwargs,
         )
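
For reference, a minimal sketch of how the corrected `sync_bn` path is used once this diff is applied. The engine class and the `address`/`port`/`sync_bn` parameters come from the diff itself; the concrete address and port values below are illustrative only, not part of the patch.

```python
# Illustrative sketch (not part of the patch): with sync_bn=True the APEX
# engine now converts BatchNorm layers via apex.parallel.convert_syncbn_model
# during init_components, instead of the previous (incorrect) call to the
# native torch.nn.SyncBatchNorm.convert_sync_batchnorm.
from catalyst.engines.apex import DistributedDataParallelAPEXEngine

engine = DistributedDataParallelAPEXEngine(
    address="127.0.0.1",  # example value; address to use for the backend
    port=2112,            # example value; port to use for the backend
    sync_bn=True,         # BatchNorm stats are synchronized across processes by Apex
)
```

The fairscale engine in this diff gains the same `sync_bn: bool = False` keyword and simply forwards it to its parent engine's `__init__`, so the flag behaves consistently across the distributed engines.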