Node-local domain decomp #968

Open

wants to merge 16 commits into base: main

Changes from 3 commits
285 changes: 168 additions & 117 deletions mirgecom/simutil.py
@@ -74,6 +74,8 @@
import sys
from functools import partial
from typing import TYPE_CHECKING, Dict, List, Optional
from contextlib import contextmanager

from logpyle import IntervalTimer

import grudge.op as op
@@ -888,6 +890,94 @@ def generate_and_distribute_mesh(comm, generate_mesh, **kwargs):
return distribute_mesh(comm, generate_mesh)


def _partition_single_volume_mesh(
mesh, num_ranks, rank_per_element, *, return_ranks=None):
rank_to_elements = {
rank: np.where(rank_per_element == rank)[0]
for rank in range(num_ranks)}

from meshmode.mesh.processing import partition_mesh
return partition_mesh(
mesh, rank_to_elements, return_parts=return_ranks)
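
Reviewer note: this helper inverts the flat `rank_per_element` assignment into per-rank element-index arrays and hands them to meshmode's `partition_mesh`. A self-contained sketch of just the inversion step, with toy values that are illustrative only (not from the PR):

```python
import numpy as np

num_ranks = 3
# Hypothetical METIS-style output: element i belongs to rank_per_element[i].
rank_per_element = np.array([0, 2, 1, 0, 2, 2, 1, 0])

# Group element indices by destination rank, as the helper above does.
rank_to_elements = {
    rank: np.where(rank_per_element == rank)[0]
    for rank in range(num_ranks)}

print(rank_to_elements)
# {0: array([0, 3, 7]), 1: array([2, 6]), 2: array([1, 4, 5])}
```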


def _partition_multi_volume_mesh(
mesh, num_ranks, rank_per_element, tag_to_elements, volume_to_tags, *,
return_ranks=None):
if return_ranks is None:
return_ranks = list(range(num_ranks))

tag_to_volume = {
tag: vol
for vol, tags in volume_to_tags.items()
for tag in tags}

volumes = list(volume_to_tags.keys())

volume_index_per_element = np.full(mesh.nelements, -1, dtype=int)
for tag, elements in tag_to_elements.items():
volume_index_per_element[elements] = volumes.index(
tag_to_volume[tag])

if np.any(volume_index_per_element < 0):
raise ValueError("Missing volume specification for some elements.")

part_id_to_elements = {
PartID(volumes[vol_idx], rank):
np.where(
(volume_index_per_element == vol_idx)
& (rank_per_element == rank))[0]
for vol_idx in range(len(volumes))
for rank in range(num_ranks)}

# FIXME: Find a better way to do this
part_id_to_part_index = {
part_id: part_index
for part_index, part_id in enumerate(part_id_to_elements.keys())}
from meshmode.mesh.processing import _compute_global_elem_to_part_elem
global_elem_to_part_elem = _compute_global_elem_to_part_elem(
mesh.nelements, part_id_to_elements, part_id_to_part_index,
mesh.element_id_dtype)

tag_to_global_to_part = {
tag: global_elem_to_part_elem[elements, :]
for tag, elements in tag_to_elements.items()}

part_id_to_tag_to_elements = {}
for part_id in part_id_to_elements.keys():
part_idx = part_id_to_part_index[part_id]
part_tag_to_elements = {}
for tag, global_to_part in tag_to_global_to_part.items():
part_tag_to_elements[tag] = global_to_part[
global_to_part[:, 0] == part_idx, 1]
part_id_to_tag_to_elements[part_id] = part_tag_to_elements

return_parts = {
PartID(vol, rank)
for vol in volumes
for rank in return_ranks}

from meshmode.mesh.processing import partition_mesh
part_id_to_mesh = partition_mesh(
mesh, part_id_to_elements, return_parts=return_parts)

return {
rank: {
vol: (
part_id_to_mesh[PartID(vol, rank)],
part_id_to_tag_to_elements[PartID(vol, rank)])
for vol in volumes}
for rank in return_ranks}
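
Reviewer note: the multi-volume path first inverts `volume_to_tags` so every mesh tag resolves to its volume, then buckets elements by (volume, rank) pairs. The sketch below replays those two steps on toy data; plain tuples stand in for `PartID`, and every name and value here is hypothetical:

```python
import numpy as np

volume_to_tags = {"fluid": ["inlet", "bulk"], "wall": ["solid"]}

# Invert volume_to_tags so each mesh tag maps back to its volume.
tag_to_volume = {
    tag: vol
    for vol, tags in volume_to_tags.items()
    for tag in tags}

volumes = list(volume_to_tags)  # ["fluid", "wall"]
num_ranks = 2

# Hypothetical per-element data for an 8-element mesh.
tag_to_elements = {
    "inlet": np.array([0, 1]),
    "bulk": np.array([2, 3, 4]),
    "solid": np.array([5, 6, 7])}
rank_per_element = np.array([0, 0, 1, 1, 0, 0, 1, 1])

# Map each element to the index of its volume; -1 flags missing coverage.
volume_index_per_element = np.full(8, -1, dtype=int)
for tag, elements in tag_to_elements.items():
    volume_index_per_element[elements] = volumes.index(tag_to_volume[tag])

# Bucket elements by (volume, rank), mirroring part_id_to_elements above.
part_id_to_elements = {
    (volumes[vol_idx], rank): np.where(
        (volume_index_per_element == vol_idx)
        & (rank_per_element == rank))[0]
    for vol_idx in range(len(volumes))
    for rank in range(num_ranks)}

print(part_id_to_elements)
# {('fluid', 0): array([0, 1, 4]), ('fluid', 1): array([2, 3]),
#  ('wall', 0): array([5]), ('wall', 1): array([6, 7])}
```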


@contextmanager
def _manage_mpi_comm(comm):
try:
yield comm
finally:
comm.Free()
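
Reviewer note: communicators returned by `Split_type` must be released with `Free()`, and the context manager above guarantees that even when the body raises. A minimal usage sketch built on the `_manage_mpi_comm` helper defined above, assuming an MPI launch (e.g. `mpiexec -n 4 python demo.py`):

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Group ranks that share a node into one communicator, keyed by world rank.
with _manage_mpi_comm(
        comm.Split_type(
            MPI.COMM_TYPE_SHARED, comm.Get_rank(), MPI.INFO_NULL)
) as node_comm:
    # Node-local rank 0 can act as the per-node I/O/partitioning rank.
    print(f"world rank {comm.Get_rank()} -> node rank {node_comm.Get_rank()}")
# node_comm.Free() has already run here, raise or no raise.
```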


def distribute_mesh(comm, get_mesh_data, partition_generator_func=None, logmgr=None):
r"""Distribute a mesh among all ranks in *comm*.

@@ -924,10 +1014,12 @@ def distribute_mesh(comm, get_mesh_data, partition_generator_func=None, logmgr=None):
global_nelements: :class:`int`
The number of elements in the global mesh
"""
from mpi4py import MPI
from mpi4py.util import pkl5
comm_wrapper = pkl5.Intracomm(comm)
from meshmode.distributed import mpi_distribute
# pkl5_comm = pkl5.Intracomm(comm)

num_ranks = comm_wrapper.Get_size()
num_ranks = comm.Get_size()
t_mesh_dist = IntervalTimer("t_mesh_dist", "Time spent distributing mesh data.")
t_mesh_data = IntervalTimer("t_mesh_data", "Time spent getting mesh data.")
t_mesh_part = IntervalTimer("t_mesh_part", "Time spent partitioning the mesh.")
@@ -938,132 +1030,91 @@ def partition_generator_func(mesh, tag_to_elements, num_ranks):
from meshmode.distributed import get_partition_by_pymetis
return get_partition_by_pymetis(mesh, num_ranks)

if comm_wrapper.Get_rank() == 0:
if logmgr:
logmgr.add_quantity(t_mesh_data)
with t_mesh_data.get_sub_timer():
with _manage_mpi_comm(
comm.Split_type(MPI.COMM_TYPE_SHARED, comm.Get_rank(), MPI.INFO_NULL)
) as node_comm:
node_comm_wrapper = pkl5.Intracomm(node_comm)
node_ranks = node_comm_wrapper.gather(comm.Get_rank(), root=0)
my_node_rank = node_comm_wrapper.Get_rank()

if my_node_rank == 0:
if logmgr:
logmgr.add_quantity(t_mesh_data)
with t_mesh_data.get_sub_timer():
global_data = get_mesh_data()
else:
global_data = get_mesh_data()
else:
global_data = get_mesh_data()

from meshmode.mesh import Mesh
if isinstance(global_data, Mesh):
mesh = global_data
tag_to_elements = None
volume_to_tags = None
elif isinstance(global_data, tuple) and len(global_data) == 3:
mesh, tag_to_elements, volume_to_tags = global_data
else:
raise TypeError("Unexpected result from get_mesh_data")

if logmgr:
logmgr.add_quantity(t_mesh_part)
with t_mesh_part.get_sub_timer():
from meshmode.mesh import Mesh
if isinstance(global_data, Mesh):
mesh = global_data
tag_to_elements = None
volume_to_tags = None
elif isinstance(global_data, tuple) and len(global_data) == 3:
mesh, tag_to_elements, volume_to_tags = global_data
else:
raise TypeError("Unexpected result from get_mesh_data")

if logmgr:
logmgr.add_quantity(t_mesh_part)
with t_mesh_part.get_sub_timer():
rank_per_element = \
partition_generator_func(mesh, tag_to_elements,
num_ranks)
else:
rank_per_element = partition_generator_func(mesh, tag_to_elements,
num_ranks)
else:
rank_per_element = partition_generator_func(mesh, tag_to_elements,
num_ranks)

def get_rank_to_mesh_data():
from meshmode.mesh.processing import partition_mesh
if tag_to_elements is None:
rank_to_elements = {
rank: np.where(rank_per_element == rank)[0]
for rank in range(num_ranks)}

rank_to_mesh_data_dict = partition_mesh(mesh, rank_to_elements)
def get_rank_to_mesh_data():
if tag_to_elements is None:
rank_to_mesh_data = _partition_single_volume_mesh(
mesh, num_ranks, rank_per_element,
return_ranks=node_ranks)
else:
rank_to_mesh_data = _partition_multi_volume_mesh(
mesh, num_ranks, rank_per_element, tag_to_elements,
volume_to_tags, return_ranks=node_ranks)

rank_to_mesh_data = [
rank_to_mesh_data_dict[rank]
for rank in range(num_ranks)]
rank_to_node_rank = {
rank: node_rank
for node_rank, rank in enumerate(node_ranks)}

else:
tag_to_volume = {
tag: vol
for vol, tags in volume_to_tags.items()
for tag in tags}

volumes = list(volume_to_tags.keys())

volume_index_per_element = np.full(mesh.nelements, -1, dtype=int)
for tag, elements in tag_to_elements.items():
volume_index_per_element[elements] = volumes.index(
tag_to_volume[tag])

if np.any(volume_index_per_element < 0):
raise ValueError("Missing volume specification "
"for some elements.")

part_id_to_elements = {
PartID(volumes[vol_idx], rank):
np.where(
(volume_index_per_element == vol_idx)
& (rank_per_element == rank))[0]
for vol_idx in range(len(volumes))
for rank in range(num_ranks)}

# TODO: Add a public meshmode function to accomplish this? So we're
# not depending on meshmode internals
part_id_to_part_index = {
part_id: part_index
for part_index, part_id in enumerate(part_id_to_elements.keys())}
from meshmode.mesh.processing import \
_compute_global_elem_to_part_elem
global_elem_to_part_elem = _compute_global_elem_to_part_elem(
mesh.nelements, part_id_to_elements, part_id_to_part_index,
mesh.element_id_dtype)

tag_to_global_to_part = {
tag: global_elem_to_part_elem[elements, :]
for tag, elements in tag_to_elements.items()}

part_id_to_tag_to_elements = {}
for part_id in part_id_to_elements.keys():
part_idx = part_id_to_part_index[part_id]
part_tag_to_elements = {}
for tag, global_to_part in tag_to_global_to_part.items():
part_tag_to_elements[tag] = global_to_part[
global_to_part[:, 0] == part_idx, 1]
part_id_to_tag_to_elements[part_id] = part_tag_to_elements

part_id_to_mesh = partition_mesh(mesh, part_id_to_elements)

rank_to_mesh_data = [
{
vol: (
part_id_to_mesh[PartID(vol, rank)],
part_id_to_tag_to_elements[PartID(vol, rank)])
for vol in volumes}
for rank in range(num_ranks)]

return rank_to_mesh_data

if logmgr:
logmgr.add_quantity(t_mesh_split)
with t_mesh_split.get_sub_timer():
rank_to_mesh_data = get_rank_to_mesh_data()
else:
rank_to_mesh_data = get_rank_to_mesh_data()
node_rank_to_mesh_data = {
rank_to_node_rank[rank]: mesh_data
for rank, mesh_data in rank_to_mesh_data.items()}

global_nelements = comm_wrapper.bcast(mesh.nelements, root=0)
return node_rank_to_mesh_data

if logmgr:
logmgr.add_quantity(t_mesh_dist)
with t_mesh_dist.get_sub_timer():
local_mesh_data = comm_wrapper.scatter(rank_to_mesh_data, root=0)
else:
local_mesh_data = comm_wrapper.scatter(rank_to_mesh_data, root=0)
if logmgr:
logmgr.add_quantity(t_mesh_split)
with t_mesh_split.get_sub_timer():
node_rank_to_mesh_data = get_rank_to_mesh_data()
else:
node_rank_to_mesh_data = get_rank_to_mesh_data()

else:
global_nelements = comm_wrapper.bcast(None, root=0)
global_nelements = node_comm_wrapper.bcast(mesh.nelements, root=0)

if logmgr:
logmgr.add_quantity(t_mesh_dist)
with t_mesh_dist.get_sub_timer():
local_mesh_data = comm_wrapper.scatter(None, root=0)
else:
local_mesh_data = comm_wrapper.scatter(None, root=0)
if logmgr:
logmgr.add_quantity(t_mesh_dist)
with t_mesh_dist.get_sub_timer():
local_mesh_data = mpi_distribute(
node_comm_wrapper, source_rank=0,
source_data=node_rank_to_mesh_data)
else:
local_mesh_data = mpi_distribute(
node_comm_wrapper, source_rank=0,
source_data=node_rank_to_mesh_data)

else: # my_node_rank > 0, get mesh part from MPI
global_nelements = node_comm_wrapper.bcast(None, root=0)

if logmgr:
logmgr.add_quantity(t_mesh_dist)
with t_mesh_dist.get_sub_timer():
local_mesh_data = \
mpi_distribute(node_comm_wrapper, source_rank=0)
else:
local_mesh_data = mpi_distribute(node_comm_wrapper, source_rank=0)

return local_mesh_data, global_nelements
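
Reviewer note: the net effect of the new path is that only node-local rank 0 calls `get_mesh_data()` and partitions the mesh, and the parts reach the other ranks on the node through `mpi_distribute` over the shared-memory communicator. A minimal caller-side sketch; `gen_mesh` and its mesh size are hypothetical:

```python
from mpi4py import MPI

from mirgecom.simutil import distribute_mesh


def gen_mesh():
    # Hypothetical mesh factory; with this PR it runs only on the
    # rank-0 process of each node, not on every rank.
    from meshmode.mesh.generation import generate_regular_rect_mesh
    return generate_regular_rect_mesh(
        a=(0,)*2, b=(1,)*2, nelements_per_axis=(16,)*2)


# Single-volume case: each rank gets its local Mesh plus the global count.
local_mesh, global_nelements = distribute_mesh(MPI.COMM_WORLD, gen_mesh)
print(f"rank {MPI.COMM_WORLD.Get_rank()}: "
      f"{local_mesh.nelements} of {global_nelements} elements")
```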
