diff --git a/apex/contrib/cudnn_gbn/batch_norm.py b/apex/contrib/cudnn_gbn/batch_norm.py
index d5f1fabd8..0a2669b15 100644
--- a/apex/contrib/cudnn_gbn/batch_norm.py
+++ b/apex/contrib/cudnn_gbn/batch_norm.py
@@ -4,12 +4,12 @@ from torch import Tensor
 import peer_memory_cuda as pm
 import cudnn_gbn_lib
-from torch.cuda.amp import custom_fwd, custom_bwd
+from torch.amp import custom_fwd, custom_bwd
 
 
 class _GroupBatchNorm2d(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type='cuda')
     def forward(ctx, input, weight, bias, running_mean, running_variance,
                 minibatch_mean, minibatch_inv_var, momentum, eps, group_size, group_rank, fwd_buffers, bwd_buffers):
         ctx.save_for_backward(input, weight, minibatch_mean, minibatch_inv_var)
@@ -21,7 +21,7 @@ def forward(ctx, input, weight, bias, running_mean, running_variance,
                                        minibatch_mean, minibatch_inv_var, momentum, eps, group_size, group_rank, fwd_buffers)
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad_output):
         x, scale, minibatch_mean, minibatch_inv_var = ctx.saved_variables
         eps = ctx.eps
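
For context: torch.cuda.amp.custom_fwd/custom_bwd were deprecated in favor of the
device-generic torch.amp variants, which take a device_type argument when used as
decorator factories. A minimal sketch of the new-style usage on a toy
autograd.Function (the _Scale class below is hypothetical, not part of this patch),
assuming a PyTorch version where torch.amp.custom_fwd/custom_bwd accept device_type:

    import torch
    from torch.amp import custom_fwd, custom_bwd

    class _Scale(torch.autograd.Function):
        @staticmethod
        @custom_fwd(device_type='cuda')
        def forward(ctx, x, s):
            # custom_fwd records the autocast state (no cast_inputs given),
            # so backward can later run under the same state.
            ctx.s = s
            return x * s

        @staticmethod
        @custom_bwd(device_type='cuda')
        def backward(ctx, grad_output):
            # custom_bwd replays forward's autocast state. Return one gradient
            # per forward input: d(x*s)/dx = s, and None for the non-tensor s.
            return grad_output * ctx.s, None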