diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index cdd6e6e8c..ba5b04c4b 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -369,6 +369,7 @@ def __init__(
         limit_all_gather_events: bool = False,
         limit_reduce_scatter_events: bool = False,
         should_validate_process_group: bool = True,
+        tensor_parallel_group: Optional[ProcessGroup] = None,
     ):
         try:
             import torch._C
@@ -380,6 +381,7 @@ def __init__(
         init_start = time.time()
         super().__init__()
         self.process_group = process_group or get_process_group_cached()
+        self.tensor_parallel_group = tensor_parallel_group
         # If ProcessGroupName.default is passed in, the reduce_scatter will use the same process group with
         # the rest of operations. The overlap feature in the backward propagation is disabled.
         if process_group_reduce_scatter == ProcessGroupName.default:
@@ -1726,6 +1728,17 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         if not self._require_backward_grad_sync:
             return

+        # All-reduce the gradient slices of params replicated across the tensor parallel group.
+        if self.tensor_parallel_group is not None and self.tensor_parallel_group.size() > 1:
+            if self.fp32_reduce_scatter:
+                orig_grad_data = param.unsharded_main_grad.data
+            else:
+                orig_grad_data = param.grad.data
+            for idx_pair in param._param_require_tp_allreduce:
+                param_allreduce = orig_grad_data[idx_pair[0]:idx_pair[1]].contiguous()
+                torch.distributed.all_reduce(param_allreduce, group=self.tensor_parallel_group)
+                orig_grad_data[idx_pair[0]:idx_pair[1]].copy_(param_allreduce)
+
         # Wait for all work in the current stream to finish, then start the
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py
index 38265dd2b..7f3573f65 100644
--- a/fairscale/nn/misc/flatten_params_wrapper.py
+++ b/fairscale/nn/misc/flatten_params_wrapper.py
@@ -70,6 +70,13 @@ def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) ->
     def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
         """Initialize the _param_numels and _param_shapes lists."""
         self._param_numels = [p.numel() for p in params]
+        self._param_require_tp_allreduce = []
+        for idx, p in enumerate(params):
+            if getattr(p, "norm_allreduce_last_microbatch", False):
+                # Record the [start, end) offsets of this param within the flat tensor.
+                self._param_require_tp_allreduce.append(
+                    [sum(self._param_numels[0:idx]), sum(self._param_numels[0:idx + 1])]
+                )
         assert self.numel() <= sum(
             self._param_numels
         ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
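
Illustration (not part of the patch): a minimal, standalone sketch of how the two pieces fit together. FlatParameter records the [start, end) offsets of parameters tagged with norm_allreduce_last_microbatch, and the post-backward hook all-reduces exactly those slices of the flat gradient over tensor_parallel_group. The Block module below and the stubbed-out all_reduce call are illustrative assumptions; only the offset bookkeeping mirrors the diff.

import torch
import torch.nn as nn

# Toy module: the linear weight follows the normal FSDP reduction path, while the
# LayerNorm params are replicated across TP ranks and tagged for the extra all-reduce.
class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4, bias=False)
        self.norm = nn.LayerNorm(4)

block = Block()
for p in block.norm.parameters():
    p.norm_allreduce_last_microbatch = True  # marker attribute checked by FlatParameter

params = list(block.parameters())
numels = [p.numel() for p in params]

# Same bookkeeping as FlatParameter.__init__: [start, end) offsets of tagged
# params inside the flattened tensor.
require_tp_allreduce = [
    [sum(numels[:i]), sum(numels[: i + 1])]
    for i, p in enumerate(params)
    if getattr(p, "norm_allreduce_last_microbatch", False)
]
print(require_tp_allreduce)  # [[16, 20], [20, 24]] for this toy module

# Stand-in for the post-backward hook: slice the flat gradient and reduce it across
# the tensor parallel group. The all_reduce is commented out so the sketch runs
# without initializing a process group.
flat_grad = torch.randn(sum(numels))
for start, end in require_tp_allreduce:
    shard = flat_grad[start:end].contiguous()
    # torch.distributed.all_reduce(shard, group=tensor_parallel_group)  # in the real hook
    flat_grad[start:end].copy_(shard)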