diff --git a/example/pytorch/train_imagenet_resnet50_byteps.py b/example/pytorch/train_imagenet_resnet50_byteps.py index 360a36fbc..274c9bf5c 100644 --- a/example/pytorch/train_imagenet_resnet50_byteps.py +++ b/example/pytorch/train_imagenet_resnet50_byteps.py @@ -239,7 +239,7 @@ def adjust_learning_rate(epoch, batch_idx): def accuracy(output, target): # get the index of the max log-probability pred = output.max(1, keepdim=True)[1] - return pred.eq(target.view_as(pred)).cpu().float().mean() + return pred.eq(target.view_as(pred)).float().mean() def save_checkpoint(epoch):