From af6fd58fcf5a74f7c4928bf1184990477f126b03 Mon Sep 17 00:00:00 2001 From: Yimin Jiang Date: Thu, 13 Feb 2020 12:28:59 +0800 Subject: [PATCH] example: fix cuda error in pytorch imagenet (#204) --- example/pytorch/train_imagenet_resnet50_byteps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/pytorch/train_imagenet_resnet50_byteps.py b/example/pytorch/train_imagenet_resnet50_byteps.py index 360a36fbc..274c9bf5c 100644 --- a/example/pytorch/train_imagenet_resnet50_byteps.py +++ b/example/pytorch/train_imagenet_resnet50_byteps.py @@ -239,7 +239,7 @@ def adjust_learning_rate(epoch, batch_idx): def accuracy(output, target): # get the index of the max log-probability pred = output.max(1, keepdim=True)[1] - return pred.eq(target.view_as(pred)).cpu().float().mean() + return pred.eq(target.view_as(pred)).float().mean() def save_checkpoint(epoch):