diff --git a/torchacc/dist/distributed_parallel.py b/torchacc/dist/distributed_parallel.py
index 89a4130..6b39fb4 100644
--- a/torchacc/dist/distributed_parallel.py
+++ b/torchacc/dist/distributed_parallel.py
@@ -14,19 +14,18 @@ class DistributedParallel(ParallelModule):
     def __init__(self, model: torch.nn.Module, config: Config, **kwargs):
         super().__init__(model, config, **kwargs)
-
-        self.model = None
+        self._module = None
 
         if self.has_pp:
-            self.model = PipelineParallel(model, self.config, **kwargs)
+            self._module = PipelineParallel(model, self.config, **kwargs)
 
         fsdp_wrapper = SpmdFullyShardedDataParallel if self.spmd_fsdp else FullyShardedDataParallel
         if self.has_fsdp:
-            if self.model is None:
-                self.model = fsdp_wrapper(model, self.config, **kwargs)
+            if self._module is None:
+                self._module = fsdp_wrapper(model, self.config, **kwargs)
             else:
-                model = self.model._get_underlay_model()
+                model = self._module._get_underlay_model()
                 model = fsdp_wrapper(model, self.config, **kwargs)
-                self.model._update_underlay_model(model)
+                self._module._update_underlay_model(model)
 
         need_wrap_dp = False
         if config.is_eager_backend():
@@ -35,32 +34,32 @@ def __init__(self, model: torch.nn.Module, config: Config, **kwargs):
             need_wrap_dp = self.has_dp and not self.has_tp
 
         if need_wrap_dp:
-            if self.model is None:
-                self.model = DataParallel(model, self.config, **kwargs)
+            if self._module is None:
+                self._module = DataParallel(model, self.config, **kwargs)
             else:
-                model = self.model._get_underlay_model()
-                model = DataParallel(model, self.config, **kwargs)
-                self.model._update_underlay_model(model)
+                module = self._module._get_underlay_model()
+                module = DataParallel(module, self.config, **kwargs)
+                self._module._update_underlay_model(module)
 
-        if self.model is None:
-            self.model = model
+        if self._module is None:
+            self._module = model
 
     def _get_underlay_model(self):
-        if isinstance(self.model, ParallelModule):
-            return self.model._get_underlay_model()
-        return self.model
+        if isinstance(self._module, ParallelModule):
+            return self._module._get_underlay_model()
+        return self._module
 
-    def _update_underlay_model(self, model: torch.nn.Module):
-        if isinstance(self.model, ParallelModule):
-            self.model._update_underlay_model(model)
+    def _update_underlay_model(self, module: torch.nn.Module):
+        if isinstance(self._module, ParallelModule):
+            self._module._update_underlay_model(module)
         else:
-            self.model = model
+            self._module = module
 
     def clip_grad_norm_(self, max_grad_norm):
-        if hasattr(self.model, "clip_grad_norm_"):
-            self.model.clip_grad_norm_(max_grad_norm)
+        if hasattr(self._module, "clip_grad_norm_"):
+            self._module.clip_grad_norm_(max_grad_norm)
         else:
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
+            torch.nn.utils.clip_grad_norm_(self._module.parameters(),
                                            max_grad_norm)
 
     def forward(self, *args, output_fn=None, **kwargs):
@@ -69,11 +68,12 @@ def forward(self, *args, output_fn=None, **kwargs):
                 "output_fn is only supported for pipeline parallel")
         if output_fn:
             kwargs["output_fn"] = output_fn
-        return self.model(*args, **kwargs)
+        return self._module(*args, **kwargs)
 
     def forward_backward(self, *args, output_fn=None, **kwargs):
         if not self.has_pp:
             raise NotImplementedError(
                 "forward_backward is only supported for pipeline parallel.")
-        assert isinstance(self.model, PipelineParallel)
-        return self.model.forward_backward(*args, output_fn=output_fn, **kwargs)
+        assert isinstance(self._module, PipelineParallel)
+        return self._module.forward_backward(
+            *args, output_fn=output_fn, **kwargs)
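For context, the constructor above composes wrappers by relying on each one exposing `_get_underlay_model` / `_update_underlay_model`, so a later wrapper (FSDP, then DP) can be nested inside whichever wrapper was applied first. Below is a minimal, runnable sketch of that pattern; `Pipeline` and `Fsdp` are stub stand-ins for illustration, not the real torchacc wrappers:

```python
import torch


class ParallelModule:
    """Stub: each wrapper stores whatever it wraps in `self._module`."""

    def __init__(self, module):
        self._module = module

    def _get_underlay_model(self):
        # Recurse until we reach the innermost plain nn.Module.
        if isinstance(self._module, ParallelModule):
            return self._module._get_underlay_model()
        return self._module

    def _update_underlay_model(self, module):
        # Replace the innermost module, keeping outer wrappers intact.
        if isinstance(self._module, ParallelModule):
            self._module._update_underlay_model(module)
        else:
            self._module = module


class Pipeline(ParallelModule):
    pass


class Fsdp(ParallelModule):
    pass


model = torch.nn.Linear(4, 4)
wrapped = Pipeline(model)                    # outer wrapper is applied first
inner = wrapped._get_underlay_model()        # -> the raw nn.Linear
wrapped._update_underlay_model(Fsdp(inner))  # nest FSDP inside the pipeline
assert wrapped._get_underlay_model() is inner
```

The invariant this sketch illustrates is that `_get_underlay_model` always returns the innermost plain `nn.Module`, so each successive wrapper wraps the raw model rather than another wrapper's public interface.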
diff --git a/torchacc/dist/parallel_module.py b/torchacc/dist/parallel_module.py
index caa33c5..e6cb00f 100644
--- a/torchacc/dist/parallel_module.py
+++ b/torchacc/dist/parallel_module.py
@@ -67,3 +67,9 @@ def forward_backward(self, *args, output_fn=None, **kwargs):
         """
         raise NotImplementedError(
             "forward_backward is only supported for pipeline parallel.")
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self._get_underlay_model(), name)
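This fallback makes the wrappers transparent to attribute access: `__getattr__` only fires when normal lookup (including `nn.Module`'s parameter/buffer/submodule resolution) fails, so wrapper-defined methods keep precedence while everything else is served by the wrapped model. A minimal sketch of the same pattern; `Wrapper` and `custom_flag` are hypothetical names used only for illustration:

```python
import torch


class Wrapper(torch.nn.Module):
    """Hypothetical stand-in for ParallelModule (illustration only)."""

    def __init__(self, module):
        super().__init__()
        self._wrapped = module  # registered as a submodule by nn.Module

    def _get_underlay_model(self):
        return self._wrapped

    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ resolves _parameters/_buffers/_modules.
            return super().__getattr__(name)
        except AttributeError:
            # Anything else is delegated to the wrapped module.
            return getattr(self._get_underlay_model(), name)


inner = torch.nn.Linear(4, 4)
inner.custom_flag = True  # hypothetical attribute, for demonstration

w = Wrapper(inner)
print(w.custom_flag)  # True -> resolved via the fallback
print(w.in_features)  # 4    -> also delegated to the Linear
```

One caveat worth noting: because the fallback swallows the wrapper's own `AttributeError`, a genuinely missing attribute now surfaces as missing on the underlay model rather than on the wrapper.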