Skip to content

Commit

Permalink
Optmize recv_calls
Browse files Browse the repository at this point in the history
  • Loading branch information
wangvsa committed Aug 24, 2021
1 parent 76ae10d commit e7fae79
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
8 changes: 5 additions & 3 deletions tools/verifyio/gen_nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from itertools import repeat
import match_mpi

ANY_SOURCE = -2
ANY_TAG = -1
Expand Down Expand Up @@ -43,7 +44,7 @@ class VerifyIOContext:
def __init__(self, reader, mpi_sync_calls):
self.num_ranks = reader.GM.total_ranks
self.all_calls = [[] for i in repeat(None, self.num_ranks)]
self.recv_calls = [[] for i in repeat(None, self.num_ranks)]
self.recv_calls = [[[] for i in repeat(None, self.num_ranks)] for j in repeat(None, self.num_ranks)]
self.send_calls = [0 for i in repeat(None, self.num_ranks)]
self.wait_test_calls = [[] for i in repeat(None, self.num_ranks)]
self.coll_calls = [{} for i in repeat(None, self.num_ranks)]
Expand Down Expand Up @@ -82,7 +83,7 @@ def is_all_to_all_call(self, func_name):
return True
return False

def generate_mpi_nodes(self, reader):
def generate_mpi_nodes(self, reader, translate):
def mpi_status_to_src_tag(status_str):
if status_str.startswith("["):
return status_str[1:-1].split("_")[0], status_str[1:-1].split("_")[1]
Expand Down Expand Up @@ -222,7 +223,8 @@ def mpi_status_to_src_tag(status_str):
if self.is_send_call(call):
self.send_calls[rank] += 1
if self.is_recv_call(call):
self.recv_calls[rank].append(idx)
global_src = match_mpi.local2global(translate, comm, int(src))
self.recv_calls[rank][global_src].append(idx)
if call.startswith("MPI_Wait") or call.startswith("MPI_Test"):
self.wait_test_calls[rank].append(idx)

Expand Down
24 changes: 10 additions & 14 deletions tools/verifyio/match_mpi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import sys
from gen_nodes import VerifyIOContext
import gen_nodes

edges = []

Expand Down Expand Up @@ -156,17 +155,15 @@ def match_pt2pt(send_call, context, translate):

comm = send_call.comm
global_dst = local2global(translate, comm, send_call.dst)
global_src = send_call.rank

for recv_call_idx in context.recv_calls[global_dst]:
for recv_call_idx in context.recv_calls[global_dst][global_src]:
recv_call = context.all_calls[global_dst][recv_call_idx]

# Check for comm, src, and tag.
if recv_call.comm != comm: continue

global_src = local2global(translate, comm, recv_call.src)

if (global_src == send_call.rank or global_src == gen_nodes.ANY_SOURCE) and \
(recv_call.rtag == send_call.stag or recv_call.rtag == gen_nodes.ANY_TAG):
if (recv_call.rtag == send_call.stag or recv_call.rtag == gen_nodes.ANY_TAG):

if recv_call.is_blocking_call():
t = (recv_call.rank, recv_call.index, recv_call.func, recv_call.tend)
Expand All @@ -179,7 +176,7 @@ def match_pt2pt(send_call, context, translate):
t = (wait_call.rank, wait_call.index, wait_call.func, wait_call.tend)

if t:
context.recv_calls[global_dst].remove(recv_call_idx)
context.recv_calls[global_dst][global_src].remove(recv_call_idx)
break

if t == None:
Expand All @@ -198,8 +195,7 @@ def match_mpi_calls(reader, mpi_sync_calls=False):
translate = get_translation_table(reader)

context = VerifyIOContext(reader, mpi_sync_calls)
context.generate_mpi_nodes(reader)

context.generate_mpi_nodes(reader, translate)

for rank in range(context.num_ranks):
print("Rank: %d, recv calls: %d, send calls: %d" %(rank, len(context.recv_calls[rank]), context.send_calls[rank]))
Expand All @@ -215,11 +211,11 @@ def match_mpi_calls(reader, mpi_sync_calls=False):

# validate result
for rank in range(context.num_ranks):
if len(context.recv_calls[rank]) != 0:
print("Rank %d still has unmatched recvs: %d" %(rank, len(context.recv_calls[rank])))
for idx in context.recv_calls[rank]:
recv_call = context.all_calls[rank][idx]
#print(recv_call.index, recv_call.func, recv_call.src, recv_call.rtag)
recvs_sum = 0
for i in range(context.num_ranks):
recvs_sum += len(context.recv_calls[rank][i])
if recvs_sum:
print("Rank %d still has unmatched recvs: %d" %(rank, recvs_sum))
if len(context.coll_calls[rank]) != 0:
print("Rank %d still has unmatched colls: %d" %(rank, len(context.coll_calls[rank])))
if len(context.wait_test_calls[rank]) != 0:
Expand Down

0 comments on commit e7fae79

Please sign in to comment.