Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: torch.cuda.amp imports #5371

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions detectron2/engine/train_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def __init__(
)

if grad_scaler is None:
from torch.cuda.amp import GradScaler
from torch.amp import GradScaler

grad_scaler = GradScaler()
self.grad_scaler = grad_scaler
Expand All @@ -482,7 +482,7 @@ def run_step(self):
"""
assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
from torch.cuda.amp import autocast
from torch.amp import autocast

start = time.perf_counter()
data = next(self._data_loader_iter)
Expand Down
2 changes: 1 addition & 1 deletion tests/layers/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_aspp(self):

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_frozen_batchnorm_fp16(self):
from torch.cuda.amp import autocast
from torch.amp import autocast

C = 10
input = torch.rand(1, C, 10, 10).cuda()
Expand Down
4 changes: 2 additions & 2 deletions tests/modeling/test_model_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_roiheads_inf_nan_data(self):

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_autocast(self):
from torch.cuda.amp import autocast
from torch.amp import autocast

inputs = [{"image": torch.rand(3, 100, 100)}]
self.model.eval()
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_inf_nan_data(self):

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_autocast(self):
from torch.cuda.amp import autocast
from torch.amp import autocast

inputs = [{"image": torch.rand(3, 100, 100)}]
self.model.eval()
Expand Down