NOTE(review): line breaks were rebuilt from a whitespace-collapsed copy of this patch.
Context-line indentation and blank context lines are inferred -- TODO confirm against the tree.
The rendezvous.py post-image id on the index line predates the edits made to the added lines.

diff --git a/metaseq/distributed/rendezvous.py b/metaseq/distributed/rendezvous.py
index 08379b92d..9d54ef346 100644
--- a/metaseq/distributed/rendezvous.py
+++ b/metaseq/distributed/rendezvous.py
@@ -6,7 +6,7 @@
 from urllib.parse import urlparse
 
 from torch.distributed.constants import default_pg_timeout
-from torch.distributed import register_rendezvous_handler, Store, TCPStore
+from torch.distributed import register_rendezvous_handler, Store, TCPStore, rendezvous
 
 RETRIES = 5
 
@@ -70,4 +70,77 @@ def _error(msg):
     raise RuntimeError("Unable to perform re-rendezvous using tcpr:// method")
 
 
+# Substring of the keys torch.distributed uses for its store-based barrier.
+STORE_BASED_BARRIER_MARKER = "store_based_barrier_key:"
+
+
+class ComplicitStore(Store):
+    """Store wrapper that reports store-based barriers as already complete.
+
+    Every call is forwarded to the wrapped store, except ``add`` on a key that
+    contains STORE_BASED_BARRIER_MARKER: that returns ``world_size`` directly
+    without touching the wrapped store.
+    """
+
+    def __init__(self, store: Store, world_size: int):
+        super().__init__()
+        self.store = store
+        self.world_size = world_size
+
+    def set(self, *args, **kwargs):
+        return self.store.set(*args, **kwargs)
+
+    def get(self, *args, **kwargs):
+        return self.store.get(*args, **kwargs)
+
+    def add(self, key: str, *args, **kwargs):
+        # The marker isn't always a prefix: it's sometimes prefixed with default_pg/.
+        if STORE_BASED_BARRIER_MARKER in key:
+            # Answer as if every rank had already checked in; the wrapped store
+            # is not contacted for barrier keys.
+            return self.world_size
+        return self.store.add(key, *args, **kwargs)
+
+    def compare_set(self, *args, **kwargs):
+        return self.store.compare_set(*args, **kwargs)
+
+    def delete_key(self, *args, **kwargs):
+        return self.store.delete_key(*args, **kwargs)
+
+    def num_keys(self, *args, **kwargs):
+        return self.store.num_keys(*args, **kwargs)
+
+    def set_timeout(self, *args, **kwargs):
+        return self.store.set_timeout(*args, **kwargs)
+
+    def wait(self, *args, **kwargs):
+        return self.store.wait(*args, **kwargs)
+
+
+# Last ComplicitStore created by the barrierless handler; None until it runs.
+STORE = None
+
+
+def _tcp_retry_barrierless_rendezvous_handler(
+    url: str, timeout=default_pg_timeout, **kwargs
+):
+    """Rendezvous over tcpr:// and return its store wrapped in a ComplicitStore.
+
+    The url must start with barrierlesstcpr://. The wrapper is also kept in the
+    module-level STORE. Extra keyword arguments are accepted but not forwarded.
+    """
+    assert url.startswith("barrierlesstcpr://")
+    # Rewrite only the leading scheme; the same text appearing later in the URL
+    # must be left alone.
+    url = url.replace("barrierlesstcpr://", "tcpr://", 1)
+    rendezvous_iterator = rendezvous(url, timeout=timeout)
+    store, rank, world_size = next(rendezvous_iterator)
+    global STORE
+    STORE = ComplicitStore(store, world_size)
+    return iter([(STORE, rank, world_size)])
+
+
 register_rendezvous_handler("tcpr", _tcp_retry_rendezvous_handler)
+register_rendezvous_handler(
+    "barrierlesstcpr", _tcp_retry_barrierless_rendezvous_handler
+)
diff --git a/metaseq/distributed/utils.py b/metaseq/distributed/utils.py
index 05dd769da..7310ebe47 100644
--- a/metaseq/distributed/utils.py
+++ b/metaseq/distributed/utils.py
@@ -84,7 +84,7 @@ def _infer_slurm_init(cfg: DistributedTrainingConfig):
             host = os.environ.get("MASTER_ADDR", None)
         if host is None:
             return
-        cfg.distributed_init_method = "tcpr://{host}:{port}".format(
+        cfg.distributed_init_method = "barrierlesstcpr://{host}:{port}".format(
             host=host, port=cfg.distributed_port
         )
         nnodes = int(os.environ.get("SLURM_NNODES"))