Fairscale support for performing allreduce only in the last microbatch #1168

Draft · wants to merge 1 commit into base: ngoyal_changes_for_pp_fp8

13 changes: 13 additions & 0 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -369,6 +369,7 @@ def __init__(
        limit_all_gather_events: bool = False,
        limit_reduce_scatter_events: bool = False,
        should_validate_process_group: bool = True,
        tensor_parallel_group: Optional[ProcessGroup] = None,
    ):
        try:
            import torch._C
@@ -380,6 +381,7 @@ def __init__(
        init_start = time.time()
        super().__init__()
        self.process_group = process_group or get_process_group_cached()
        self.tensor_parallel_group = tensor_parallel_group
        # If ProcessGroupName.default is passed in, the reduce_scatter will use the same process group with
        # the rest of operations. The overlap feature in the backward propagation is disabled.
        if process_group_reduce_scatter == ProcessGroupName.default:
@@ -1726,6 +1728,17 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
        if not self._require_backward_grad_sync:
            return

        # Run the tensor-parallel allreduce on any flagged slices of the flattened
        # gradient before the reduce-scatter below. Guard against the default case
        # where no tensor parallel group was passed in.
        if self.tensor_parallel_group is not None and self.tensor_parallel_group.size() > 1:
            if self.fp32_reduce_scatter:
                orig_grad_data = param.unsharded_main_grad.data
            else:
                orig_grad_data = param.grad.data
            # _param_require_tp_allreduce is only set on flattened parameters.
            for idx_pair in getattr(param, "_param_require_tp_allreduce", []):
                param_allreduce = orig_grad_data[idx_pair[0]:idx_pair[1]].contiguous()
                torch.distributed.all_reduce(param_allreduce, group=self.tensor_parallel_group)
                orig_grad_data[idx_pair[0]:idx_pair[1]].copy_(param_allreduce)

        # Wait for all work in the current stream to finish, then start the
        # reductions in post_backward stream.
        self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
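For context, a minimal sketch of how this hook is intended to be driven (illustrative only, not part of this diff: the single-group NCCL setup, the toy nn.Linear module, and the per-microbatch .sum() loss are assumptions). All non-final microbatches run under no_sync(), so _require_backward_grad_sync is False and the hook returns before any reduction; only the last microbatch triggers the tensor-parallel allreduce above, followed by the usual reduce-scatter.

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def train_step(model: FSDP, microbatches):
    # Skip gradient reduction (and the new TP allreduce) for all but the last microbatch.
    with model.no_sync():
        for mb in microbatches[:-1]:
            model(mb).sum().backward()
    # Last microbatch: _post_backward_hook now allreduces the flagged slices
    # over tensor_parallel_group before reduce-scattering the gradient shards.
    model(microbatches[-1]).sum().backward()

if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    # Illustrative: all ranks form one tensor-parallel group; a real setup would
    # build one TP group per data-parallel replica.
    tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))
    model = FSDP(torch.nn.Linear(16, 16).cuda(), tensor_parallel_group=tp_group)
    microbatches = [torch.randn(4, 16, device="cuda") for _ in range(4)]
    train_step(model, microbatches)
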
7 changes: 7 additions & 0 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -70,6 +70,13 @@ def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) ->
    def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
        """Initialize the _param_numels and _param_shapes lists."""
        self._param_numels = [p.numel() for p in params]
        # Record [start, end) offsets into the flat buffer for every parameter flagged
        # to be allreduced across the tensor parallel group on the last microbatch.
        self._param_require_tp_allreduce = []
        for idx, p in enumerate(params):
            if getattr(p, "norm_allreduce_last_microbatch", False):
                self._param_require_tp_allreduce.append(
                    [sum(self._param_numels[:idx]), sum(self._param_numels[: idx + 1])]
                )
        assert self.numel() <= sum(
            self._param_numels
        ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
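To make the bookkeeping above concrete, here is a small standalone sketch (plain modules, not the actual FlatParameter class) showing how flagging a parameter with norm_allreduce_last_microbatch turns into [start, end) index pairs over the flattened buffer, i.e. exactly the slices the post-backward hook later allreduces.

import torch.nn as nn

linear = nn.Linear(4, 4, bias=False)    # 16 elements at flat indices [0, 16)
norm = nn.LayerNorm(4)                  # weight: 4 elements, bias: 4 elements
# Flag the norm weight the way a tensor-parallel model would before wrapping.
norm.weight.norm_allreduce_last_microbatch = True

params = [linear.weight, norm.weight, norm.bias]
numels = [p.numel() for p in params]

# Same prefix-sum logic as FlatParameter.__init__ above.
require_tp_allreduce = [
    [sum(numels[:i]), sum(numels[: i + 1])]
    for i, p in enumerate(params)
    if getattr(p, "norm_allreduce_last_microbatch", False)
]
print(require_tp_allreduce)  # [[16, 20]] -> norm.weight occupies flat indices 16..19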