im4MEC.py
# Adapted from https://github.com/AIRMEC/im4MEC

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


class Attn_Net_Gated(nn.Module):
    """Gated attention network that produces one unnormalized attention score per tile and per attention class."""

    def __init__(self, L=1024, D=256, dropout=False, p_dropout_atn=0.25, n_classes=1):
        super(Attn_Net_Gated, self).__init__()
        self.attention_a = [nn.Linear(L, D), nn.Tanh()]
        self.attention_b = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout:
            self.attention_a.append(nn.Dropout(p_dropout_atn))
            self.attention_b.append(nn.Dropout(p_dropout_atn))
        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        a = self.attention_a(x)   # N x D, tanh branch
        b = self.attention_b(x)   # N x D, sigmoid gate
        A = a.mul(b)              # element-wise gating
        A = self.attention_c(A)   # N x n_classes unnormalized attention scores
        return A
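

# A minimal shape sanity check for the gated attention module (not part of the
# original file; the bag size of 500 tiles and the default sizes below are
# illustrative assumptions):
#
#   attn = Attn_Net_Gated(L=1024, D=256, n_classes=4)
#   scores = attn(torch.randn(500, 1024))  # 500 tiles -> scores of shape (500, 4)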


class Im4MEC(nn.Module):
    def __init__(
        self,
        input_feature_size=1024,
        precompression_layer=True,
        feature_size_comp=512,
        feature_size_attn=256,
        dropout=True,
        p_dropout_fc=0.25,
        p_dropout_atn=0.25,
        n_classes=4,
    ):
        super(Im4MEC, self).__init__()
        self.n_classes = n_classes

        if precompression_layer:
            # MLP that compresses the input tile features before attention pooling.
            self.compression_layer = nn.Sequential(*[
                nn.Linear(input_feature_size, feature_size_comp * 4),
                nn.ReLU(),
                nn.Dropout(p_dropout_fc),
                nn.Linear(feature_size_comp * 4, feature_size_comp * 2),
                nn.ReLU(),
                nn.Dropout(p_dropout_fc),
                nn.Linear(feature_size_comp * 2, feature_size_comp),
                nn.ReLU(),
                nn.Dropout(p_dropout_fc)])
            dim_post_compression = feature_size_comp
        else:
            self.compression_layer = nn.Identity()
            dim_post_compression = input_feature_size

        # Gated attention network with one attention branch per class.
        self.attention_net = Attn_Net_Gated(
            L=dim_post_compression,
            D=feature_size_attn,
            dropout=dropout,
            p_dropout_atn=p_dropout_atn,
            n_classes=self.n_classes)

        # Classification head: one linear classifier per class.
        self.classifiers = nn.ModuleList(
            [nn.Linear(dim_post_compression, 1) for _ in range(self.n_classes)]
        )

        # Init weights.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Xavier initialization for linear layers; biases start at zero.
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward_attention(self, h):
        # h has shape N_tiles x dim_post_compression.
        A_ = self.attention_net(h)         # N_tiles x n_classes
        A_raw = torch.transpose(A_, 1, 0)  # n_classes x N_tiles
        A = F.softmax(A_raw, dim=-1)       # normalize attention scores over the tiles
        return A_raw, A

    def forward(self, h):
        h = self.compression_layer(h)

        # Attention MIL pooling: one attention-weighted average of the tile embeddings per class.
        A_raw, A = self.forward_attention(h)  # both n_classes x N_tiles
        M = A @ h  # n_classes x dim_post_compression; row c is sum_i(a_ci * h_i)

        # Score each class-specific slide embedding with its own classifier.
        logits = torch.empty(1, self.n_classes).float().to(h.device)
        for c in range(self.n_classes):
            logits[0, c] = self.classifiers[c](M[c])

        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)
        return logits, Y_prob, Y_hat, A_raw, M
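

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original repository): run a forward pass
    # on a random bag of tile features. The bag size (500 tiles) and the feature size
    # (1024, matching the default input_feature_size) are illustrative assumptions.
    model = Im4MEC()
    model.eval()
    bag = torch.randn(500, 1024)  # N_tiles x input_feature_size
    with torch.no_grad():
        logits, Y_prob, Y_hat, A_raw, M = model(bag)
    # With the defaults: logits and Y_prob are 1x4, Y_hat holds the predicted class
    # index, A_raw is 4x500 (per-class tile attention), and M is 4x512.
    print(logits.shape, Y_prob.shape, Y_hat.item(), A_raw.shape, M.shape)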