-
Notifications
You must be signed in to change notification settings - Fork 227
/
focalloss_test.py
45 lines (36 loc) · 1.11 KB
/
focalloss_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import os,sys,random,time
import argparse
from focalloss import *
# Sanity check for FocalLoss: with gamma=0 it is expected to reduce to plain
# cross-entropy, so the largest absolute difference against PyTorch's own
# cross-entropy over many random batches should be ~0 (float rounding only).

# Fall back to CPU so the check still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _max_abs_error(make_batch, loss_fn, reference_fn, iterations):
    """Compare two loss functions on freshly generated random batches.

    Args:
        make_batch: zero-argument callable returning an ``(inputs, labels)``
            pair; called once per iteration.
        loss_fn: loss under test, called as ``loss_fn(inputs, labels)``.
        reference_fn: reference loss, called the same way.
        iterations: number of batches to evaluate.

    Returns:
        ``(elapsed_seconds, max_error)`` where ``max_error`` is the largest
        ``abs(loss_fn(...) - reference_fn(...))`` seen across all batches.
    """
    start_time = time.time()
    max_error = 0.0
    for _ in range(iterations):
        inputs, labels = make_batch()
        # .item() converts the scalar loss to a Python float; the old
        # ``.data[0]`` idiom raises IndexError on 0-dim tensors (PyTorch >= 0.4).
        a = loss_fn(inputs, labels).item()
        b = reference_fn(inputs, labels).item()
        max_error = max(max_error, abs(a - b))
    return time.time() - start_time, max_error


def _binary_batch():
    """Return 12800 samples x 2 classes; logits scaled by a random int in [1, 10]."""
    x = torch.rand(12800, 2) * random.randint(1, 10)
    # ge(0.1) makes roughly 90% of the labels class 1 (an imbalanced problem).
    l = torch.rand(12800).ge(0.1).long()
    return x.to(device), l.to(device)


def _dense_batch(num_classes=1000):
    """Return per-position logits (128, C, 8, 4) and labels (128, 8, 4)."""
    x = torch.rand(128, num_classes, 8, 4) * random.randint(1, 10)
    # Integer labels in [0, num_classes).
    l = (torch.rand(128, 8, 4) * num_classes).long()
    return x.to(device), l.to(device)


focal_loss = FocalLoss(gamma=0)

elapsed, maxe = _max_abs_error(_binary_batch, focal_loss,
                               nn.CrossEntropyLoss(), 1000)
print('time:', elapsed, 'max_error:', maxe)

# NLLLoss accepts (N, C, d1, d2) input directly and replaces the deprecated
# NLLLoss2d. dim=1 is the class axis, which is also what the old implicit-dim
# log_softmax call picked for 4-D input.
nll_loss = nn.NLLLoss()
elapsed, maxe = _max_abs_error(
    _dense_batch, focal_loss,
    lambda inputs, labels: nll_loss(F.log_softmax(inputs, dim=1), labels),
    100)
print('time:', elapsed, 'max_error:', maxe)