diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py
index 17db16496848..b9c89b65c931 100644
--- a/nemo/core/optim/distributed_adam.py
+++ b/nemo/core/optim/distributed_adam.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import collections
+import contextlib
 import itertools
 from typing import Callable, Dict, Iterable, Optional, Union
 
@@ -108,6 +109,8 @@ class MegatronDistributedFusedAdam(DistributedFusedAdam):
             but requires larger memory than distributing within all
             ranks, especially for pure data parallel models. (default:
             False).
+        lock_timeout (float, optional): timeout for callback mutex in
+            seconds.
         **kwargs: keyword arguments to pass to Apex
             DistributedFusedAdam.
 
@@ -118,6 +121,7 @@ def __init__(
         params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
         disable_distributed_parameters: bool = False,
         distribute_within_nodes: bool = False,
+        lock_timeout: Optional[float] = None,
         **kwargs,
     ):
 
@@ -152,6 +156,25 @@ def __init__(
         # Construct distributed optimizer
         super().__init__(param_groups, **kwargs)
 
+        # Create mutex with timeout
+        self._lock_with_timeout = None
+        if lock_timeout is not None:
+
+            @contextlib.contextmanager
+            def lock_with_timeout():
+                result = self._lock.acquire(timeout=lock_timeout)
+                try:
+                    yield result
+                finally:
+                    if result:
+                        # Acquired lock before timeout
+                        self._lock.release()
+                    else:
+                        # Failed to acquire lock before timeout
+                        print(f'MegatronDistributedFusedAdam: Failed to acquire lock within {lock_timeout} seconds.')
+
+            self._lock_with_timeout = lock_with_timeout
+
     def _broadcast_params(self) -> None:
         # Assume params have already been synchronized
         pass
@@ -166,7 +189,10 @@ def hook(*unused):
                     'before the forward pass (e.g. by calling data_ptr) '
                     'or run DistributedFusedAdam with overlap_param_sync=False.'
                 )
-            with self._lock:
+            lock = self._lock
+            if self._lock_with_timeout is not None:
+                lock = self._lock_with_timeout()
+            with lock:
                 need_to_initialize = 'fragments' not in self.state[param]
                 if need_to_initialize:
                     self._init_param_state(param, param_group_id, param_id)
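
Usage sketch (not part of the patch, added for context): a minimal, hypothetical example of constructing the optimizer with the new lock_timeout option. The model, learning rate, and distributed/Megatron setup are placeholder assumptions; only the lock_timeout argument comes from this change. When the post-backward hook cannot acquire the callback mutex within the timeout, it prints a message and proceeds without the lock instead of blocking indefinitely.

    # Hypothetical usage sketch; assumes torch.distributed and the Megatron
    # parallel state are already initialized, as in a normal NeMo training run.
    import torch

    from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

    model = torch.nn.Linear(8, 8)  # stand-in for a real Megatron module

    optimizer = MegatronDistributedFusedAdam(
        model.parameters(),
        lr=1e-4,  # forwarded to Apex DistributedFusedAdam via **kwargs
        # New option from this patch: wait at most 30 s for the callback
        # mutex; on timeout the hook logs a message and continues unlocked.
        lock_timeout=30.0,
    )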