Commit

Add option for mutex timeout in distributed optimizer backward hook (#9087)

* Tim: Add option for timeout in distopt callback mutex

Signed-off-by: Jaemin Choi <[email protected]>

* Replace parent's _lock

Signed-off-by: Jaemin Choi <[email protected]>

* Revert "Replace parent's _lock"

This reverts commit 972d1b6.

Signed-off-by: Jaemin Choi <[email protected]>

* Raise RuntimeError on timeout

Signed-off-by: Jaemin Choi <[email protected]>

* Change RuntimeError to print

Signed-off-by: Jaemin Choi <[email protected]>

---------

Signed-off-by: Jaemin Choi <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
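
For reference, the new option is exposed as a constructor argument on MegatronDistributedFusedAdam. A minimal usage sketch, not taken from this commit: the model and learning rate are placeholder assumptions, and Apex's DistributedFusedAdam expects torch.distributed to be initialized before construction.

    import torch
    from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

    # Placeholder model; any torch.nn.Module's parameters work here.
    model = torch.nn.Linear(16, 16)

    # With lock_timeout set, the post-backward hook waits at most 60 s
    # for the callback mutex instead of blocking indefinitely.
    optimizer = MegatronDistributedFusedAdam(
        model.parameters(),
        lr=1e-4,
        lock_timeout=60.0,
    )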
minitu and Jaemin Choi authored May 2, 2024
1 parent 7135609 commit b2eccd2
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion nemo/core/optim/distributed_adam.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import collections
+import contextlib
 import itertools
 from typing import Callable, Dict, Iterable, Optional, Union
 
@@ -108,6 +109,8 @@ class MegatronDistributedFusedAdam(DistributedFusedAdam):
             but requires larger memory than distributing within all
             ranks, especially for pure data parallel models.
             (default: False).
+        lock_timeout (float, optional): timeout for callback mutex in
+            seconds.
         **kwargs: keyword arguments to pass to Apex
             DistributedFusedAdam.
 
@@ -118,6 +121,7 @@ def __init__(
         params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
         disable_distributed_parameters: bool = False,
         distribute_within_nodes: bool = False,
+        lock_timeout: Optional[float] = None,
         **kwargs,
     ):
 
@@ -152,6 +156,25 @@ def __init__(
         # Construct distributed optimizer
         super().__init__(param_groups, **kwargs)
 
+        # Create mutex with timeout
+        self._lock_with_timeout = None
+        if lock_timeout is not None:
+
+            @contextlib.contextmanager
+            def lock_with_timeout():
+                result = self._lock.acquire(timeout=lock_timeout)
+                try:
+                    yield result
+                finally:
+                    if result:
+                        # Acquired lock before timeout
+                        self._lock.release()
+                    else:
+                        # Failed to acquire lock before timeout
+                        print(f'MegatronDistributedFusedAdam: Failed to acquire lock within {lock_timeout} seconds.')
+
+            self._lock_with_timeout = lock_with_timeout
+
     def _broadcast_params(self) -> None:
         # Assume params have already been synchronized
         pass
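
The hunk above wraps threading.Lock.acquire(timeout=...) in a generator-based context manager so callers can keep using a plain with statement. A self-contained sketch of the same pattern outside NeMo (names here are illustrative):

    import contextlib
    import threading

    _lock = threading.Lock()

    @contextlib.contextmanager
    def lock_with_timeout(timeout: float):
        # acquire() returns False if the lock could not be taken in time.
        acquired = _lock.acquire(timeout=timeout)
        try:
            yield acquired
        finally:
            if acquired:
                _lock.release()
            else:
                print(f'Failed to acquire lock within {timeout} seconds.')

    with lock_with_timeout(1.0) as acquired:
        # The body runs either way; `acquired` reports whether the lock
        # is actually held, mirroring the behavior in the diff.
        pass

As the commit history notes, an earlier revision raised RuntimeError on timeout; the merged version only prints, so the guarded section still executes even when the lock was not acquired.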
@@ -166,7 +189,10 @@ def hook(*unused):
                     'before the forward pass (e.g. by calling data_ptr) '
                     'or run DistributedFusedAdam with overlap_param_sync=False.'
                 )
-            with self._lock:
+            lock = self._lock
+            if self._lock_with_timeout is not None:
+                lock = self._lock_with_timeout()
+            with lock:
                 need_to_initialize = 'fragments' not in self.state[param]
                 if need_to_initialize:
                     self._init_param_state(param, param_group_id, param_id)
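
Both branches hand the with statement a valid context manager: a threading.Lock implements __enter__/__exit__ directly, while the timeout variant is a decorated function that must be called to produce one, which is why the hook uses self._lock_with_timeout(). A small illustrative sketch of the distinction, not taken from the diff:

    import contextlib
    import threading

    lock = threading.Lock()

    @contextlib.contextmanager
    def managed():
        yield

    # A Lock object is itself a context manager:
    with lock:
        pass

    # A @contextlib.contextmanager function must be *called* to get one:
    with managed():
        pass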
