From 64aea56497b49f5a510b922f8bf878bf8e6bf8fa Mon Sep 17 00:00:00 2001
From: Wei Wang
Date: Wed, 8 May 2024 14:50:01 -0700
Subject: [PATCH 1/2] Avoid importing apex transformer automatically and make
 error messages more clear when apex.transformer is explicitly called on
 unsupported platform

---
 apex/__init__.py          | 1 -
 apex/transformer/utils.py | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/apex/__init__.py b/apex/__init__.py
index 74851f5b3..be739694b 100644
--- a/apex/__init__.py
+++ b/apex/__init__.py
@@ -24,7 +24,6 @@
 # load time) the error message is timely and visible.
 from . import optimizers
 from . import normalization
-from . import transformer
 
 
 # Logging utilities for apex.transformer module
diff --git a/apex/transformer/utils.py b/apex/transformer/utils.py
index 4434e3604..0991bd862 100644
--- a/apex/transformer/utils.py
+++ b/apex/transformer/utils.py
@@ -8,6 +8,7 @@
 # The following 4 lines are for backward comparability with
 # older PyTorch.
 if "all_gather_into_tensor" not in dir(torch.distributed):
+    assert torch.distributed.is_available(), "PyTorch Distributed is Not available or Disabled."
     torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
 
 def ensure_divisibility(numerator, denominator):

From bfeb06aa05835fdbd328d29aa5aa3f69b432e693 Mon Sep 17 00:00:00 2001
From: Wei Wang <143543872+nWEIdia@users.noreply.github.com>
Date: Wed, 8 May 2024 18:38:49 -0700
Subject: [PATCH 2/2] Update apex/transformer/utils.py

Co-authored-by: Masaki Kozuki
---
 apex/transformer/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/apex/transformer/utils.py b/apex/transformer/utils.py
index 0991bd862..39d5d7668 100644
--- a/apex/transformer/utils.py
+++ b/apex/transformer/utils.py
@@ -8,7 +8,8 @@
 # The following 4 lines are for backward comparability with
 # older PyTorch.
 if "all_gather_into_tensor" not in dir(torch.distributed):
-    assert torch.distributed.is_available(), "PyTorch Distributed is Not available or Disabled."
+    if not torch.distributed.is_available():
+        raise RuntimeError("PyTorch Distributed is Not available or Disabled.")
     torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
 
 def ensure_divisibility(numerator, denominator):
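
Note: combined, the two commits leave the shim in apex/transformer/utils.py
in the shape sketched below (a minimal reconstruction from the hunks above,
with surrounding code elided). The dir() check keeps the alias for older
PyTorch, and the new availability guard turns the previous failure, an
AttributeError on the private torch.distributed._all_gather_base, into a
clear RuntimeError:

    import torch

    if "all_gather_into_tensor" not in dir(torch.distributed):
        if not torch.distributed.is_available():
            raise RuntimeError("PyTorch Distributed is Not available or Disabled.")
        # Older PyTorch: expose the private collective under its newer public name.
        torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base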
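
Note: the user-facing effect of PATCH 1/2, sketched under the assumption of a
PyTorch build without the distributed package (for example one built with
USE_DISTRIBUTED=0), where torch.distributed.is_available() returns False:

    import apex                   # succeeds: apex.transformer is no longer imported at load time
    from apex import transformer  # explicit opt-in; on such a build this now fails
                                  # fast with RuntimeError("PyTorch Distributed is
                                  # Not available or Disabled.")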