mcr_dl_megatron changes
Signed-off-by: Radha Guhane <[email protected]>
RadhaGulhane13 committed Apr 20, 2024
1 parent cee0b27 commit 8b7f2f6
Showing 3 changed files with 18 additions and 9 deletions.
17 changes: 12 additions & 5 deletions mcr_dl/__init__.py
@@ -22,30 +22,33 @@
global __dist_engine
global __dist_backend

__dist_engine = None
__dist_backend = None

def init_torch_distributed(backend):
    import torch.distributed as dist
    if backend == 'nccl':
        mpi_discovery()
    elif backend == 'mpi':
        set_mpi_dist_environemnt()
-    dist.init_process_group(backend)
+    dist.init_process_group(backend=backend)
    local_rank = int(os.environ['LOCAL_RANK'])
-    get_accelerator().set_device(local_rank)
+    # get_accelerator().set_device(local_rank)
    print(f'Rank : {dist.get_rank()} World_Size : {dist.get_world_size()}', flush = True)

def init_mcr_dl_comm(backend):
    import mcr_dl
    mcr_dl.init_distributed(dist_backend=backend, use_mcr_dl=True)
    local_rank = int(os.environ['LOCAL_RANK'])
    #get_accelerator().set_device(local_rank)

-def init_processes(dist_engine, dist_backend):
+def init_processes(dist_engine, dist_backend, world_size = -1, rank = -1, timeout = None, init_method = None):
    print(f'Comm : {dist_engine} Backend : {dist_backend}')

    global __dist_engine
    global __dist_backend
    __dist_engine = dist_engine
    __dist_backend = dist_backend

    if dist_engine == 'mcr_dl':
        init_mcr_dl_comm(dist_backend)
    elif dist_engine == 'torch':
@@ -56,8 +59,12 @@ def init_processes(dist_engine, dist_backend):

def get_distributed_engine():
    global __dist_engine
+    if __dist_engine is None:
+        return None
    if __dist_engine == 'torch':
        return torch.distributed
    elif __dist_engine == 'mcr_dl':
        import mcr_dl
        return mcr_dl
+    print(f"Unsupported values for __dist_engine. Expected values 'torch' or 'mcr_dl'")
+    exit(0)
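
For orientation, here is a minimal usage sketch of the updated entry points in mcr_dl/__init__.py. It is illustrative only and not part of the commit; it assumes a launcher (torchrun or mpirun) has already exported RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, and MASTER_PORT, and that the mcr_dl module mirrors the torch.distributed get_rank/get_world_size API.

import mcr_dl

# Initialize with the torch engine and the NCCL backend; the newly added
# world_size, rank, timeout, and init_method parameters fall back to their
# defaults (-1, -1, None, None) when omitted.
mcr_dl.init_processes(dist_engine='torch', dist_backend='nccl')

# get_distributed_engine() now returns None before initialization; otherwise it
# hands back either torch.distributed or the mcr_dl module itself.
dist = mcr_dl.get_distributed_engine()
if dist is not None:
    print(f"rank {dist.get_rank()} of {dist.get_world_size()}")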
5 changes: 3 additions & 2 deletions mcr_dl/mpi.py
@@ -72,12 +72,13 @@ def destroy_process_group(self, group=None):
        pass

    def new_group(self, ranks):
-        # TODO: Change this to use comm_op.new_group when the impl. is ready.
+        # TODO: Change this to use self.mpi_comm_op.new_group(ranks) when the impl. is ready.
        if not torch.distributed.is_initialized():
            from mcr_dl.torch import TorchBackend
-            d = TorchBackend(rank=self.rank, size=self.size)
+            d = TorchBackend(rank=self.rank, world_size=self.size)
        logger.info(f"new group called with {ranks}")
        return torch.distributed.new_group(ranks)
+        # return self.mpi_comm_op.new_group(ranks)

    def get_rank(self, group=None):
        return self.mpi_comm_op.get_rank(0)
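A small sketch of what the keyword fix in new_group amounts to; the standalone construction and the values below are illustrative, not part of this commit.

from mcr_dl.torch import TorchBackend

# TorchBackend.__init__ takes world_size, not size, so the fallback path in
# new_group now matches the real signature:
d = TorchBackend(rank=0, world_size=2)
# d = TorchBackend(rank=0, size=2)   # old call; 'size' is not a parameter, so this raises TypeError
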
5 changes: 3 additions & 2 deletions mcr_dl/torch.py
@@ -23,6 +23,7 @@
from .utils import *
from .backend import *
from .comm import *
+from .constants import default_pg_timeout

DS_COMM_ALL_GATHER_OFF = False
DS_COMM_REDUCE_SCATTER_OFF = False
@@ -119,7 +120,7 @@ class TorchBackend(Backend):
        needed.
    """

-    def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
+    def __init__(self, backend="mpi", init_method = None, timeout = default_pg_timeout, rank=-1, world_size=-1, name='torch'):
        super(TorchBackend, self).__init__()
        self.has_all_reduce_coalesced = has_all_reduce_coalesced()
        self.has_coalescing_manager = has_coalescing_manager()
@@ -131,7 +132,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
        # The idea is to fake that dist backend is initialized even when
        # it is not so we can run on a single GPU without doing any init_process_group
        self.single_gpu_mode = True
-        self.init_process_group(backend, timeout, init_method, rank, world_size)
+        self.init_process_group(backend=backend, init_method=init_method, timeout= timeout, rank=rank, world_size= world_size)

    @classmethod
    def get_all_gather_function(self):
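A hedged sketch of how the relaxed TorchBackend signature can now be used; it assumes the usual torch.distributed environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) are set by the launcher, and that default_pg_timeout is whatever mcr_dl.constants defines.

from mcr_dl.torch import TorchBackend

# Every constructor argument now has a default (backend="mpi", init_method=None,
# timeout=default_pg_timeout, rank=-1, world_size=-1), so a bare construction works:
backend = TorchBackend()

# Explicit keyword form, mirroring how __init__ now forwards keyword arguments
# to init_process_group:
backend = TorchBackend(backend="nccl", init_method="env://", rank=0, world_size=1)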
