-
Notifications
You must be signed in to change notification settings - Fork 373
/
amsoftmax.py
52 lines (42 loc) · 1.56 KB
/
amsoftmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#! /usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
class AMSoftmax(nn.Module):
    """Additive-margin softmax loss.

    Features and class weight vectors are L2-normalised so that their
    product is a cosine similarity. The margin ``m`` is subtracted from the
    cosine of each sample's target class, the result is multiplied by the
    scale ``s`` and then fed to a standard cross-entropy loss.

    Args:
        in_feats: size of the feature vectors passed to ``forward``.
        n_classes: number of classes (columns of the weight matrix ``W``).
        m: margin subtracted from the target-class cosine.
        s: scale applied to the margin-adjusted cosines.
    """

    def __init__(self,
                 in_feats,
                 n_classes=10,
                 m=0.3,
                 s=15):
        super(AMSoftmax, self).__init__()
        self.m = m
        self.s = s
        self.in_feats = in_feats
        # One column of weights per class; re-initialised with Xavier below.
        self.W = torch.nn.Parameter(torch.randn(in_feats, n_classes), requires_grad=True)
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.W, gain=1)

    def forward(self, x, lb):
        """Return the scalar loss for features ``x`` and integer labels ``lb``.

        Args:
            x: tensor of shape ``(batch, in_feats)``.
            lb: integer (long) tensor with ``batch`` class indices.

        Raises:
            AssertionError: if the batch sizes of ``x`` and ``lb`` differ or
                the feature dimension of ``x`` is not ``in_feats``.
        """
        assert x.size()[0] == lb.size()[0], \
            "x and lb must have the same batch size"
        assert x.size()[1] == self.in_feats, \
            "x must have in_feats features per sample"

        # Unit-length rows of x and unit-length columns of W, so that the
        # matrix product below yields cosine similarities.
        x_unit = nn.functional.normalize(x, p=2, dim=1, eps=1e-12)
        w_unit = nn.functional.normalize(self.W, p=2, dim=0, eps=1e-12)
        costh = torch.mm(x_unit, w_unit)

        # Margin tensor: m at each sample's target class, 0 elsewhere.
        # zeros_like keeps it on the same device and dtype as the logits,
        # so no CPU round-trip or hard-coded .cuda() call is needed.
        margin = torch.zeros_like(costh).scatter_(1, lb.view(-1, 1), self.m)

        return self.ce(self.s * (costh - margin), lb)
if __name__ == '__main__':
    # Smoke test: one forward/backward pass on random data, then print the
    # loss value and the shape and type of the module's first parameter.
    loss_fn = AMSoftmax(1024, 10)
    feats = torch.randn(20, 1024)
    labels = torch.randint(0, 10, (20, ), dtype=torch.long)

    value = loss_fn(feats, labels)
    value.backward()
    print(value.detach().numpy())

    first_param = next(loss_fn.parameters())
    print(first_param.shape)
    print(type(first_param))