Skip to content

Commit

Permalink
example: fix cuda error in pytorch imagenet (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymjiang authored Feb 13, 2020
1 parent 7e688ff commit af6fd58
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion example/pytorch/train_imagenet_resnet50_byteps.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def adjust_learning_rate(epoch, batch_idx):
def accuracy(output, target):
# get the index of the max log-probability
pred = output.max(1, keepdim=True)[1]
return pred.eq(target.view_as(pred)).cpu().float().mean()
return pred.eq(target.view_as(pred)).float().mean()


def save_checkpoint(epoch):
Expand Down

0 comments on commit af6fd58

Please sign in to comment.