Merge pull request #100 from tomoki0924/fix/distribute
explicitly choose whether or not to use torch.distributed
Andrei Panferov authored Jul 21, 2023
2 parents 304c3c4 + 3efdf94 commit 20b1cfb
Showing 4 changed files with 19 additions and 9 deletions.
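
For context, a minimal usage sketch of the new flag from the public factory. It assumes the tensor_parallel factory exposes the same distributed keyword that it forwards internally (see factory.py below), that two CUDA devices are available, and uses an example Hugging Face checkpoint:

import tensor_parallel as tp
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # example checkpoint

# distributed=False explicitly opts out of torch.distributed collectives,
# even if torch.distributed.is_initialized() happens to be True in this process;
# the default (True) keeps the previous auto-detection behaviour.
model_tp = tp.tensor_parallel(model, device_ids=["cuda:0", "cuda:1"], distributed=False)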
src/tensor_parallel/config.py (10 changes: 5 additions & 5 deletions)
@@ -51,16 +51,16 @@ def __init__(
             all_rules[i] = dict((re.compile(pattern), actions) for pattern, actions in rule_set.items())
         self.state_rules, self.input_rules, self.output_rules, self.attr_rules = all_rules

-    def create_collective_ops(self, devices: Sequence[torch.device]):
+    def create_collective_ops(self, devices: Sequence[torch.device], distributed: bool = True):
         """
         Return a copy of config with thread-parallel collective operations, such as AllGather and AllReduce
         :note: this function should be called during TensorParallel init, before making shards
         """
         return dataclasses.replace(
             self,
-            input_rules=create_collective_ops(self.input_rules, devices),
-            output_rules=create_collective_ops(self.output_rules, devices),
+            input_rules=create_collective_ops(self.input_rules, devices, distributed),
+            output_rules=create_collective_ops(self.output_rules, devices, distributed),
         )


@@ -82,13 +82,13 @@ def convert_legacy_state_action(state_action: Any) -> StateAction:
     raise Exception(f"Can't convert {state_action} of type {type(state_action)} to StateAction")


-def create_collective_ops(rules: dict, devices: Sequence[torch.device]):
+def create_collective_ops(rules: dict, devices: Sequence[torch.device], distributed: bool = True):
     """Initialize collective thread-parallel operations from config rules"""
     world_size = len(devices)
     all_cuda = all(device.type == "cuda" for device in devices)
     unique_output_transforms = {op for output_actions in rules.values() for op in output_actions.values()}
     transform_map = {}
-    if torch.distributed.is_initialized():
+    if torch.distributed.is_initialized() and distributed:
         make_allreduce, make_allgather = DistributedAllReduce, DistributedAllGather
     elif all_cuda and not TENSOR_PARALLEL_USE_NATIVE:
         make_allreduce, make_allgather = NCCLAllReduce, NCCLAllGather
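
The effect of the changed condition, as a standalone illustration. The helper below is illustrative only, not the library's code; the strings mirror the operation classes referenced in this hunk, and the final fallback branch is an assumption about the code hidden below the visible lines:

import torch

def pick_collective_backend(devices, distributed=True, use_native=False):
    # Mirrors the selection in create_collective_ops: torch.distributed collectives
    # are now used only when the process group exists AND the caller opted in.
    all_cuda = all(d.type == "cuda" for d in devices)
    if torch.distributed.is_initialized() and distributed:
        return "DistributedAllReduce / DistributedAllGather"
    elif all_cuda and not use_native:
        return "NCCLAllReduce / NCCLAllGather"
    return "native thread-parallel ops (assumed fallback, not shown in this diff)"

print(pick_collective_backend([torch.device("cuda:0"), torch.device("cuda:1")], distributed=False))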
src/tensor_parallel/factory.py (12 changes: 10 additions & 2 deletions)
@@ -61,14 +61,22 @@ def tensor_parallel(
     else:
         if isinstance(module, PreTrainedModel):
             module = TensorParallelPreTrainedModel(
-                module, device_ids=device_ids, tensor_parallel_config=tensor_parallel_config, **kwargs
+                module,
+                device_ids=device_ids,
+                tensor_parallel_config=tensor_parallel_config,
+                distributed=distributed,
+                **kwargs,
             )
             module.wrapped_model = _maybe_sharded(
                 module.wrapped_model, sharded, num_trainable_parameters, sharded_param_names=sharded_param_names
             )
         else:
             module = TensorParallel(
-                module, device_ids=device_ids, tensor_parallel_config=tensor_parallel_config, **kwargs
+                module,
+                device_ids=device_ids,
+                tensor_parallel_config=tensor_parallel_config,
+                distributed=distributed,
+                **kwargs,
             )
             module = _maybe_sharded(module, sharded, num_trainable_parameters, sharded_param_names=sharded_param_names)

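The second branch means plain nn.Module wrappers honour the flag as well. A small sketch using the TensorParallel class directly, with a toy module and device list chosen only for illustration and assuming two CUDA devices:

import torch
import torch.nn as nn
from tensor_parallel import TensorParallel

# Toy module; TensorParallel shards its linear layers across the listed devices.
block = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# Explicitly keep this wrapper on thread-parallel (non-torch.distributed) collectives.
tp_block = TensorParallel(
    block,
    device_ids=[torch.device("cuda:0"), torch.device("cuda:1")],
    distributed=False,
)
output = tp_block(torch.randn(2, 1024, device="cuda:0"))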
src/tensor_parallel/pretrained_model.py (3 changes: 2 additions & 1 deletion)
@@ -43,6 +43,7 @@ def __init__(
         output_device: Optional[torch.device] = None,
         output_device_index: Optional[int] = None,
         tensor_parallel_config: Optional[Config] = None,
+        distributed: bool = True,
     ):
         super().__init__(module.config)  # Temporary empty config. Gets replaced in from_pretrained

@@ -55,7 +56,7 @@ def __init__(
             tensor_parallel_config = find_predefined_tensor_parallel_config(module.config, device_ids)

         self.wrapped_model = TensorParallel(
-            module, device_ids, output_device, output_device_index, tensor_parallel_config
+            module, device_ids, output_device, output_device_index, tensor_parallel_config, distributed=distributed
         )

     @property
src/tensor_parallel/tensor_parallel.py (3 changes: 2 additions & 1 deletion)
@@ -30,6 +30,7 @@ def __init__(
         output_device_index: Optional[int] = None,
         tensor_parallel_config: Optional[Config] = None,
         delay_init: bool = False,
+        distributed: bool = True,
     ):
         super().__init__()
         original_params = sum(p.numel() for p in module.parameters())
@@ -66,7 +67,7 @@ def __init__(
             tensor_parallel_config = add_lora_rules(module, tensor_parallel_config)
         self.tensor_parallel_config = tensor_parallel_config

-        config_with_ops = tensor_parallel_config.create_collective_ops(self.devices)
+        config_with_ops = tensor_parallel_config.create_collective_ops(self.devices, distributed)
         # ^-- creates a copy of comfig with collective op instances, such as AllReduce and AllGather

         for rank, device in enumerate(self.devices):
