fusion_strategy.py

# Attention-based fusion of two feature maps: channel attention (global
# avg / max / nuclear pooling) and spatial attention (soft-max over
# channel-collapsed maps), averaged in attention_fusion_weight.
import torch
import torch.nn.functional as F

EPSILON = 1e-5


# attention fusion strategy: average of the channel-weighted and spatial-weighted maps
def attention_fusion_weight(tensor1, tensor2, p_type):
    # p_type selects the channel pooling: 'attention_avg', 'attention_max' or 'attention_nuclear'
    f_channel = channel_fusion(tensor1, tensor2, p_type)
    f_spatial = spatial_fusion(tensor1, tensor2)
    tensor_f = (f_channel + f_spatial) / 2
    return tensor_f


# channel fusion: weight each input by its channel attention
def channel_fusion(tensor1, tensor2, p_type):
    shape = tensor1.size()
    # calculate channel attention (global pooling over the spatial dimensions)
    global_p1 = channel_attention(tensor1, p_type)
    global_p2 = channel_attention(tensor2, p_type)
    # get weight maps and broadcast them to the spatial size
    global_p_w1 = global_p1 / (global_p1 + global_p2 + EPSILON)
    global_p_w2 = global_p2 / (global_p1 + global_p2 + EPSILON)
    global_p_w1 = global_p_w1.repeat(1, 1, shape[2], shape[3])
    global_p_w2 = global_p_w2.repeat(1, 1, shape[2], shape[3])
    tensor_f = global_p_w1 * tensor1 + global_p_w2 * tensor2
    return tensor_f


def spatial_fusion(tensor1, tensor2, spatial_type='mean'):
    shape = tensor1.size()
    # calculate spatial attention
    spatial1 = spatial_attention(tensor1, spatial_type)
    spatial2 = spatial_attention(tensor2, spatial_type)
    # get weight maps, soft-max over the two sources
    spatial_w1 = torch.exp(spatial1) / (torch.exp(spatial1) + torch.exp(spatial2) + EPSILON)
    spatial_w2 = torch.exp(spatial2) / (torch.exp(spatial1) + torch.exp(spatial2) + EPSILON)
    spatial_w1 = spatial_w1.repeat(1, shape[1], 1, 1)
    spatial_w2 = spatial_w2.repeat(1, shape[1], 1, 1)
    tensor_f = spatial_w1 * tensor1 + spatial_w2 * tensor2
    return tensor_f


# channel attention: global pooling over the spatial dimensions
def channel_attention(tensor, pooling_type='attention_avg'):
    shape = tensor.size()
    pooling_function = F.avg_pool2d
    if pooling_type == 'attention_avg':
        pooling_function = F.avg_pool2d
    elif pooling_type == 'attention_max':
        pooling_function = F.max_pool2d
    elif pooling_type == 'attention_nuclear':
        pooling_function = nuclear_pooling
    global_p = pooling_function(tensor, kernel_size=shape[2:])
    return global_p


# spatial attention: collapse the channel dimension
def spatial_attention(tensor, spatial_type='sum'):
    if spatial_type == 'mean':
        spatial = tensor.mean(dim=1, keepdim=True)
    else:  # 'sum'
        spatial = tensor.sum(dim=1, keepdim=True)
    return spatial


# nuclear pooling: per-channel sum of singular values (nuclear norm)
def nuclear_pooling(tensor, kernel_size=None):
    shape = tensor.size()
    vectors = torch.zeros(1, shape[1], 1, 1, device=tensor.device)
    for i in range(shape[1]):
        u, s, v = torch.svd(tensor[0, i, :, :] + EPSILON)
        s_sum = torch.sum(s)
        vectors[0, i, 0, 0] = s_sum
    return vectors
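

# Minimal usage sketch (not part of the original module; the shapes and random
# inputs below are illustrative assumptions): fuse two single-image feature maps
# with the average-pooling channel attention.
if __name__ == '__main__':
    feat_a = torch.rand(1, 64, 32, 32)  # hypothetical encoder features, source A
    feat_b = torch.rand(1, 64, 32, 32)  # hypothetical encoder features, source B
    fused = attention_fusion_weight(feat_a, feat_b, 'attention_avg')
    print(fused.shape)  # expected: torch.Size([1, 64, 32, 32])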