AVENet1.py (forked from alu222/SIFDriveNet)
from image_convnet import *
from audio_convnet import *

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from utils.mydata_xu import *
class AVENet1(nn.Module):

    def __init__(self, Bias, f, R, h):
        super(AVENet1, self).__init__()
        # Fusion parameters.
        # R: rank of the factorized fusion (e.g. 4), h: fused feature size (e.g. 8).
        # self.R = 4
        # self.h = 8
        self.w_img = Parameter(torch.randn(R, 129, h))  # (R, 129, h), e.g. (4, 129, 8)
        self.w_aud = Parameter(torch.randn(R, 129, h))
        self.w_f = Parameter(f)      # (1, R)
        self.bias = Parameter(Bias)  # (1, h)
        self.relu = F.relu

        self.imgnet = ImageConvNet()
        self.audnet = AudioConvNet()

        # Image branch head.
        self.vpool4 = nn.MaxPool2d(14, stride=14)
        self.vfc1 = nn.Linear(512, 128)
        self.vfc2 = nn.Linear(128, 128)
        self.vl2norm = nn.BatchNorm1d(128)

        # Audio branch head.
        self.apool4 = nn.MaxPool2d((16, 12), stride=(16, 12))
        self.afc1 = nn.Linear(512, 128)
        self.afc2 = nn.Linear(128, 128)
        self.al2norm = nn.BatchNorm1d(128)

        # Combining layers.
        self.mse = F.mse_loss
        # self.fc3 = nn.Linear(1, 2)
        self.fc3 = nn.Linear(8, 3)
        self.softmax = F.softmax
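
    # The parameters above appear to implement a rank-R factorized bilinear
    # fusion. A rough sketch of the intended computation, assuming row vectors
    # x_img, x_aud of length 129 (the 128-d embeddings with a constant 1 appended):
    #
    #   z_r = (x_img @ w_img[r]) * (x_aud @ w_aud[r])   # (h,) for each rank r
    #   out = w_f @ stack([z_1, ..., z_R]) + bias        # (h,), then fc3 -> 3 classes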
    def forward(self, image, audio):
        # Image branch: conv features -> pooling -> FC -> batch norm.
        img = self.imgnet(image)
        img = self.vpool4(img).squeeze(2).squeeze(2)
        img = self.relu(self.vfc1(img))
        img = self.vfc2(img)
        img = self.vl2norm(img)

        # Audio branch: conv features -> pooling -> FC -> batch norm.
        aud = self.audnet(audio)
        aud = self.apool4(aud).squeeze(2).squeeze(2)
        aud = self.relu(self.afc1(aud))
        aud = self.afc2(aud)
        aud = self.al2norm(aud)

        # Fusion.
        n = img.shape[0]  # batch size
        img = img.cuda()
        # Append a constant 1 to each feature vector: (n, 128) -> (n, 129).
        img = torch.cat([img, torch.ones(n, 1).cuda()], dim=1)
        aud = torch.cat([aud, torch.ones(n, 1).cuda()], dim=1)
        fusion_img = torch.matmul(img, self.w_img)  # (R, n, h)
        fusion_aud = torch.matmul(aud, self.w_aud)  # (R, n, h)
        fusion_img_aud = fusion_img * fusion_aud    # (R, n, h), element-wise product
        # print(self.w_f.shape)                         # (1, R)
        # print(fusion_img_aud.permute(1, 0, 2).shape)  # (n, R, h)
        # Combine the R ranks with w_f and add the bias: result is (n, h).
        fusion_img_aud = torch.matmul(self.w_f, fusion_img_aud.permute(1, 0, 2)).squeeze() + self.bias
        out = self.fc3(fusion_img_aud)
        # print(out)
        # out = self.softmax(out, 1)
        # print(out)
        # print(type(out))

        # # Alternative: join the two embeddings via their MSE distance.
        # mse = self.mse(img, aud, reduce=False).mean(1).unsqueeze(1)
        # out = self.fc3(mse)
        # out = self.softmax(out, 1)  # softmax over each row
        return out, img, aud
    def get_image_embeddings(self, image):
        # Just get the image embeddings.
        img = self.imgnet(image)
        img = self.vpool4(img).squeeze(2).squeeze(2)
        img = self.relu(self.vfc1(img))
        img = self.vfc2(img)
        img = self.vl2norm(img)
        return img
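
    def get_audio_embeddings(self, audio):
        # Symmetric helper for audio embeddings (a hypothetical addition that
        # mirrors the audio path in forward(); not required by the script below).
        aud = self.audnet(audio)
        aud = self.apool4(aud).squeeze(2).squeeze(2)
        aud = self.relu(self.afc1(aud))
        aud = self.afc2(aud)
        aud = self.al2norm(aud)
        return aud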
if __name__ == '__main__':
    # Placeholder fusion parameters: Bias is (1, h) with effectively-zero values,
    # f is (1, R) initialized to zeros.
    Bias = torch.tensor([[1.8062e-25, 7.3008e-43, 1.8062e-25, 7.3008e-43, 3.2415e-24, 7.3008e-43,
                          1.8062e-25, 7.3008e-43]], requires_grad=True).cuda()
    f = torch.tensor([[0., 0., 0., 0.]], requires_grad=True).cuda()
    model = AVENet1(Bias, f, 4, 8).cuda()

    image = Variable(torch.rand(2, 3, 224, 224)).cuda()
    speed = Variable(torch.rand(2, 1, 257, 200)).cuda()

    # Run a feedforward pass and check the shapes.
    o, _, _ = model(image, speed)
    W_img = model.w_img
    W_aud = model.w_aud
    W_f = model.w_f
    Bias = model.bias
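
    # Sanity check (assuming the test inputs above): o should be (2, 3), the
    # fusion weights (4, 129, 8) each, w_f (1, 4) and bias (1, 8).
    print(o.shape, W_img.shape, W_aud.shape, W_f.shape, Bias.shape)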