clean code
zzhhjjj committed Jul 2, 2024
1 parent beefa30 commit 3f504b3
Showing 1 changed file with 1 addition and 29 deletions.
src/nanotron/parallel/context.py (30 changes: 1 addition & 29 deletions)
@@ -85,20 +85,12 @@ def _init_parallel_groups(self):

        self.tp_pg = self.create_new_group(ranks.transpose((0, 1, 2, 3, 4)).reshape((-1, self.tensor_parallel_size)))
        self.sp_pg = self.create_new_group(ranks.transpose((1, 2, 3, 4, 0)).reshape((-1, self.sequence_parallel_size)))
-       # only need sync the gradient. Don't need to load different data. Be careful!
-       # Anything related to dp_pg get changed.
-       # Create a group DP+CP. Sync between this group.
-       # temporary fix
-       # self.dp_pg = self.sp_pg
+       # Create a DP+SP group. Gradient sync, loss averaging, and optimizer sharding happen within this group; anything tied to dp_pg changes accordingly, but data loading does not need to differ.
        self.dp_sp_pg = self.create_new_group(
            ranks.transpose((1, 4, 2, 0, 3)).reshape((-1, self.data_parallel_size * self.sequence_parallel_size))
        )  # the last two dimensions should correspond to sp and dp
-       # self.dp_pg = self.create_new_group(ranks.transpose((1, 4, 2, 0, 3)).reshape((-1, self.data_parallel_size*self.sequence_parallel_size))) #the last two dimension should correspond to sp and dp.
        self.dp_pg = self.create_new_group(ranks.transpose((4, 0, 1, 2, 3)).reshape((-1, self.data_parallel_size)))
        self.pp_pg = self.create_new_group(ranks.transpose((3, 4, 0, 1, 2)).reshape((-1, self.pipeline_parallel_size)))
-       # print("dp_sp_pg: ",ranks.transpose((1, 4, 2, 0, 3)).reshape((-1, self.data_parallel_size*self.sequence_parallel_size)))
-       # print("dp_pg: ", ranks.transpose((4, 0, 1, 2, 3)).reshape((-1, self.data_parallel_size)))
-       # print("sp_pg: ", ranks.transpose((1, 2, 3, 4, 0)).reshape((-1, self.sequence_parallel_size)))
        self.expert_pg = self.create_new_group(
            ranks.transpose((2, 3, 4, 0, 1)).reshape((-1, self.expert_parallel_size))
        )
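Note on what the kept lines compute: `ranks` is a 5-D matrix of global ranks with axes (sp, ep, pp, dp, tp), and each parallel group is formed by transposing the axis (or axes) of interest to the end and flattening the rest. The following is a minimal NumPy sketch of that arithmetic with hypothetical sizes (sp=2, ep=1, pp=1, dp=2, tp=2); it only illustrates the grouping logic, not nanotron's actual ProcessGroup creation.

# Minimal sketch, hypothetical sizes; not nanotron code.
import numpy as np

sp, ep, pp, dp, tp = 2, 1, 1, 2, 2
world_size = sp * ep * pp * dp * tp

# Global ranks laid out as a 5-D matrix with axes (sp, ep, pp, dp, tp),
# matching the indexing used by get_global_rank further down.
ranks = np.arange(world_size).reshape((sp, ep, pp, dp, tp))

# Each row of the reshaped matrix is one process group.
tp_groups = ranks.transpose((0, 1, 2, 3, 4)).reshape((-1, tp))          # tp is already the last axis
dp_groups = ranks.transpose((4, 0, 1, 2, 3)).reshape((-1, dp))          # move dp to the end
dp_sp_groups = ranks.transpose((1, 4, 2, 0, 3)).reshape((-1, dp * sp))  # last two axes are sp and dp

print(tp_groups)     # [[0 1] [2 3] [4 5] [6 7]]
print(dp_sp_groups)  # [[0 2 4 6] [1 3 5 7]] -> gradient sync spans both dp and sp ranks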
@@ -113,25 +105,6 @@ def _init_parallel_groups(self):
                for sp_rank in range(self.sequence_parallel_size)
            ]
        )
-
-       # self.tp_pg = self.create_new_group(ranks.transpose((0, 1, 2, 3)).reshape((-1, self.tensor_parallel_size)))
-       # self.dp_pg = self.create_new_group(ranks.transpose((3, 0, 1, 2)).reshape((-1, self.data_parallel_size)))
-       # self.pp_pg = self.create_new_group(ranks.transpose((2, 3, 0, 1)).reshape((-1, self.pipeline_parallel_size)))
-       # self.expert_pg = self.create_new_group(ranks.transpose((1, 2, 3, 0)).reshape((-1, self.expert_parallel_size)))
-
-       # model parallel group = combination of tp and pp and exp for a given dp rank
-       # self.mp_pg = self.create_new_group(
-       #     [ranks[:, :, dp_rank, :].reshape(-1) for dp_rank in range(self.data_parallel_size)]
-       # )
-
-       # self.tp_and_expert_pg = self.create_new_group(
-       #     [
-       #         ranks[:, pp_rank, dp_rank, :].reshape(-1)
-       #         for pp_rank in range(self.pipeline_parallel_size)
-       #         for dp_rank in range(self.data_parallel_size)
-       #     ]
-       # )
-
        self.world_rank_matrix: np.ndarray = ranks

    def create_new_group(self, all_groups_ranks: np.ndarray) -> dist.ProcessGroup:
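The helper whose signature appears above takes a matrix whose rows are rank lists and returns the group containing the calling process. A hypothetical torch.distributed sketch for orientation only; the actual nanotron implementation may differ.

import numpy as np
import torch.distributed as dist

def create_new_group(all_groups_ranks: np.ndarray) -> dist.ProcessGroup:
    # Every rank must call dist.new_group for every row, in the same order,
    # because group creation is a collective operation; each rank keeps only
    # the group it belongs to.
    current_rank = dist.get_rank()
    my_group = None
    for group_ranks in all_groups_ranks:
        group = dist.new_group(ranks=[int(r) for r in group_ranks])
        if current_rank in group_ranks:
            my_group = group
    return my_group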
@@ -190,5 +163,4 @@ def get_global_rank(
        :return: numpy.int64, The global rank.
        """
-       # return self.world_rank_matrix[sp_rank, ep_rank, pp_rank, 0, tp_rank]
        return self.world_rank_matrix[sp_rank, ep_rank, pp_rank, dp_rank, tp_rank]
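With the hypothetical sizes from the sketch above, the kept return line reduces to a plain index into the 5-D rank matrix:

# Hypothetical example, reusing `ranks` from the earlier sketch (axes sp, ep, pp, dp, tp).
global_rank = ranks[1, 0, 0, 1, 0]  # sp_rank=1, ep_rank=0, pp_rank=0, dp_rank=1, tp_rank=0
print(global_rank)                  # 6 for this layout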
