diff --git a/CHANGELOG.md b/CHANGELOG.md
index 39af7586ff..d3f5bded48 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 spectrum.
 - Added gradient clipping to StaticCapture utilities.
 - Bistride Multiscale MeshGraphNet example.
+- Support for MPICH-based MPI launching.
 
 ### Changed
 
diff --git a/modulus/distributed/manager.py b/modulus/distributed/manager.py
index 61bc2687e9..3af4380922 100644
--- a/modulus/distributed/manager.py
+++ b/modulus/distributed/manager.py
@@ -287,6 +287,40 @@ def initialize_open_mpi(addr, port):
             method="openmpi",
         )
 
+    @staticmethod
+    def initialize_mpich(addr, port):
+        """Setup method using MPICH initialization"""
+        rank = int(os.environ.get("PMI_RANK"))
+        world_size = int(os.environ.get("PMI_SIZE"))
+
+        # cray-mpich
+        if "PMI_LOCAL_RANK" in os.environ:
+            local_rank = int(os.environ.get("PMI_LOCAL_RANK"))
+        # mpich-4.2.1 / hydra
+        else:
+            local_rank = int(os.environ.get("MPI_LOCALRANKID"))
+
+        # For multi-node MPI jobs, determine "addr" as the
+        # address of global rank 0.
+        if addr == "localhost":
+            try:
+                import socket
+                from mpi4py import MPI
+                comm = MPI.COMM_WORLD
+                addr = comm.bcast(socket.gethostbyname(socket.gethostname()), root=0)
+            except ImportError:
+                pass
+
+        DistributedManager.setup(
+            rank=rank,
+            world_size=world_size,
+            local_rank=local_rank,
+            addr=addr,
+            port=port,
+            backend=DistributedManager.get_available_backend(),
+            method="mpich",
+        )
+
     @staticmethod
     def initialize_slurm(port):
         """Setup method using SLURM initialization"""
@@ -319,6 +353,9 @@ def initialize():
             `OPENMPI`: Initialization for OpenMPI launchers. Uses
             `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE` and
             `OMPI_COMM_WORLD_LOCAL_RANK` environment variables.
+            `MPICH`: Initialization for MPICH-based MPI launchers.
+            Uses `PMI_RANK`, `PMI_SIZE` and
+            either `PMI_LOCAL_RANK` or `MPI_LOCALRANKID` environment variables.
 
         Initialization by default is done using the first valid method in the order
         listed above. Initialization method can also be explicitly controlled using the
@@ -342,9 +379,11 @@ def initialize():
                     DistributedManager.initialize_slurm(port)
                 elif "OMPI_COMM_WORLD_RANK" in os.environ:
                     DistributedManager.initialize_open_mpi(addr, port)
+                elif "PMI_RANK" in os.environ:
+                    DistributedManager.initialize_mpich(addr, port)
                 else:
                     warn(
-                        "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job"
+                        "Could not initialize using ENV, SLURM, OPENMPI or MPICH methods. Assuming this is a single process job"
                     )
             DistributedManager._shared_state["_is_initialized"] = True
         elif initialization_method == "ENV":
@@ -353,13 +392,15 @@ def initialize():
             DistributedManager.initialize_slurm(port)
         elif initialization_method == "OPENMPI":
             DistributedManager.initialize_open_mpi(addr, port)
+        elif initialization_method == "MPICH":
+            DistributedManager.initialize_mpich(addr, port)
         else:
             raise RuntimeError(
                 "Unknown initialization method "
                 f"{initialization_method}. "
                 "Supported values for "
                 "MODULUS_DISTRIBUTED_INITIALIZATION_METHOD are "
-                "ENV, SLURM, OPENMPI, and MPICH"
             )
 
         # Set per rank numpy random seed for data sampling
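
For context beyond the diff: under MPICH's Hydra-style launchers, each spawned process sees `PMI_RANK`/`PMI_SIZE` (plus `MPI_LOCALRANKID`, or `PMI_LOCAL_RANK` under cray-mpich), which is exactly what the new detection branch keys on. Below is a minimal usage sketch; the script name `train_dist.py` and the `mpiexec` command in the comment are illustrative assumptions, not part of this change, while `DistributedManager` and its `rank`/`local_rank`/`world_size`/`device` properties are the existing Modulus API.

```python
# Hypothetical launch (not part of this diff): mpiexec -n 8 python train_dist.py
# Each MPICH-launched rank inherits PMI_RANK / PMI_SIZE, so initialize()
# falls through the ENV, SLURM, and OpenMPI checks and takes the MPICH path.
from modulus.distributed import DistributedManager

DistributedManager.initialize()  # detects PMI_RANK and calls initialize_mpich()
dist = DistributedManager()  # singleton; already configured by initialize()

# Each rank now knows its global/local identity and its assigned device.
print(
    f"rank {dist.rank}/{dist.world_size}, "
    f"local rank {dist.local_rank}, device {dist.device}"
)
```

The rank-0 address broadcast in `initialize_mpich` matters for the multi-node case: every rank must pass the same master address to `DistributedManager.setup`, but only rank 0 knows its own hostname, hence the optional `mpi4py` broadcast when `addr` is still the `localhost` default.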