"""
From: https://github.com/thuiar/Self-MM
Paper: Learning Modality-Specific Representations with Self-Supervised Multi-Task Learning for Multimodal Sentiment Analysis
"""
# Self-supervised multimodal multi-task learning network
import os
import sys
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from models.subNets.BertTextEncoder import BertTextEncoder
__all__ = ['SELF_MM']


class SELF_MM(nn.Module):
    def __init__(self, args):
        super(SELF_MM, self).__init__()
        # text subnets
        self.aligned = args.need_data_aligned
        self.text_model = BertTextEncoder(language=args.language, use_finetune=args.use_finetune)

        # audio-vision subnets
        audio_in, video_in = args.feature_dims[1:]
        self.audio_model = AuViSubNet(audio_in, args.a_lstm_hidden_size, args.audio_out,
                                      num_layers=args.a_lstm_layers, dropout=args.a_lstm_dropout)
        self.video_model = AuViSubNet(video_in, args.v_lstm_hidden_size, args.video_out,
                                      num_layers=args.v_lstm_layers, dropout=args.v_lstm_dropout)

        # the post_fusion layers
        self.post_fusion_dropout = nn.Dropout(p=args.post_fusion_dropout)
        self.post_fusion_layer_1 = nn.Linear(args.text_out + args.video_out + args.audio_out, args.post_fusion_dim)
        self.post_fusion_layer_2 = nn.Linear(args.post_fusion_dim, args.post_fusion_dim)
        self.post_fusion_layer_3 = nn.Linear(args.post_fusion_dim, 1)

        # the classify layer for text
        self.post_text_dropout = nn.Dropout(p=args.post_text_dropout)
        self.post_text_layer_1 = nn.Linear(args.text_out, args.post_text_dim)
        self.post_text_layer_2 = nn.Linear(args.post_text_dim, args.post_text_dim)
        self.post_text_layer_3 = nn.Linear(args.post_text_dim, 1)

        # the classify layer for audio
        self.post_audio_dropout = nn.Dropout(p=args.post_audio_dropout)
        self.post_audio_layer_1 = nn.Linear(args.audio_out, args.post_audio_dim)
        self.post_audio_layer_2 = nn.Linear(args.post_audio_dim, args.post_audio_dim)
        self.post_audio_layer_3 = nn.Linear(args.post_audio_dim, 1)

        # the classify layer for video
        self.post_video_dropout = nn.Dropout(p=args.post_video_dropout)
        self.post_video_layer_1 = nn.Linear(args.video_out, args.post_video_dim)
        self.post_video_layer_2 = nn.Linear(args.post_video_dim, args.post_video_dim)
        self.post_video_layer_3 = nn.Linear(args.post_video_dim, 1)

    def forward(self, text, audio, video):
        audio, audio_lengths = audio
        video, video_lengths = video

        # text: (batch_size, 3, seq_len) stacking [input_ids, attention_mask, segment_ids];
        # summing the attention mask recovers the true text lengths
        mask_len = torch.sum(text[:,1,:], dim=1, keepdim=True)
        text_lengths = mask_len.squeeze().int().detach().cpu()
        # keep only the [CLS] token representation of the BERT output
        text = self.text_model(text)[:,0,:]

        if self.aligned:
            audio = self.audio_model(audio, text_lengths)
            video = self.video_model(video, text_lengths)
        else:
            audio = self.audio_model(audio, audio_lengths)
            video = self.video_model(video, video_lengths)

        # fusion
        fusion_h = torch.cat([text, audio, video], dim=-1)
        fusion_h = self.post_fusion_dropout(fusion_h)
        fusion_h = F.relu(self.post_fusion_layer_1(fusion_h), inplace=False)
        # text
        text_h = self.post_text_dropout(text)
        text_h = F.relu(self.post_text_layer_1(text_h), inplace=False)
        # audio
        audio_h = self.post_audio_dropout(audio)
        audio_h = F.relu(self.post_audio_layer_1(audio_h), inplace=False)
        # vision
        video_h = self.post_video_dropout(video)
        video_h = F.relu(self.post_video_layer_1(video_h), inplace=False)

        # classifier-fusion
        x_f = F.relu(self.post_fusion_layer_2(fusion_h), inplace=False)
        output_fusion = self.post_fusion_layer_3(x_f)
        # classifier-text
        x_t = F.relu(self.post_text_layer_2(text_h), inplace=False)
        output_text = self.post_text_layer_3(x_t)
        # classifier-audio
        x_a = F.relu(self.post_audio_layer_2(audio_h), inplace=False)
        output_audio = self.post_audio_layer_3(x_a)
        # classifier-vision
        x_v = F.relu(self.post_video_layer_2(video_h), inplace=False)
        output_video = self.post_video_layer_3(x_v)

        res = {
            'M': output_fusion,
            'T': output_text,
            'A': output_audio,
            'V': output_video,
            'Feature_t': text_h,
            'Feature_a': audio_h,
            'Feature_v': video_h,
            'Feature_f': fusion_h,
        }
        return res
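

# --- Illustrative sketch (not part of the upstream Self-MM file) ------------
# Self-MM trains the fused head 'M' jointly with the three unimodal heads as a
# multi-task regression problem. The helper below is only a minimal,
# hypothetical illustration of how the dict returned by SELF_MM.forward could
# be combined into a joint L1 objective; the `labels` dict and the uniform
# weights are assumptions, not the self-supervised label generation and
# weighting scheme of the official trainer.
def _example_multitask_l1_loss(res, labels, weights=(1.0, 1.0, 1.0, 1.0)):
    """res: dict returned by SELF_MM.forward; labels: dict with keys 'M', 'T', 'A', 'V'."""
    return sum(w * F.l1_loss(res[k].view(-1), labels[k].view(-1))
               for w, k in zip(weights, ('M', 'T', 'A', 'V')))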


class AuViSubNet(nn.Module):
    def __init__(self, in_size, hidden_size, out_size, num_layers=1, dropout=0.2, bidirectional=False):
        '''
        Args:
            in_size: input dimension
            hidden_size: hidden layer dimension
            out_size: output dimension of the final linear projection
            num_layers: specify the number of layers of LSTMs
            dropout: dropout probability
            bidirectional: specify usage of bidirectional LSTM
        Output:
            (return value in forward) a tensor of shape (batch_size, out_size)
        '''
        super(AuViSubNet, self).__init__()
        self.rnn = nn.LSTM(in_size, hidden_size, num_layers=num_layers, dropout=dropout,
                           bidirectional=bidirectional, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.linear_1 = nn.Linear(hidden_size, out_size)

    def forward(self, x, lengths):
        '''
        x: (batch_size, sequence_len, in_size)
        lengths: 1-D tensor (or list) of true sequence lengths, kept on the CPU
        '''
        packed_sequence = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        _, final_states = self.rnn(packed_sequence)
        # final_states[0] is h_n of shape (num_layers * num_directions, batch_size, hidden_size);
        # squeeze() assumes a single-layer, unidirectional LSTM and a batch size greater than 1
        h = self.dropout(final_states[0].squeeze())
        y_1 = self.linear_1(h)
        return y_1
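

# --- Illustrative usage sketch (not part of the upstream Self-MM file) ------
# Shows the tensor shapes SELF_MM.forward expects, using dummy inputs and a
# hypothetical hyper-parameter namespace. Every value below is an assumption
# chosen for illustration only; the real configurations live in the MMSA
# config files, and running this requires the models.subNets.BertTextEncoder
# dependency plus the pretrained BERT weights it loads.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        need_data_aligned=False, language='en', use_finetune=False,
        feature_dims=(768, 5, 20),          # (text, audio, video) input dims
        a_lstm_hidden_size=16, a_lstm_layers=1, a_lstm_dropout=0.0, audio_out=16,
        v_lstm_hidden_size=32, v_lstm_layers=1, v_lstm_dropout=0.0, video_out=32,
        text_out=768,
        post_fusion_dim=64, post_fusion_dropout=0.1,
        post_text_dim=32, post_text_dropout=0.1,
        post_audio_dim=16, post_audio_dropout=0.1,
        post_video_dim=16, post_video_dropout=0.1,
    )
    model = SELF_MM(args)

    batch_size, seq_len = 2, 24
    # text: (batch, 3, seq_len) stacking input_ids, attention_mask, segment_ids
    input_ids = torch.randint(0, 30000, (batch_size, seq_len)).float()
    attention_mask = torch.ones(batch_size, seq_len)
    segment_ids = torch.zeros(batch_size, seq_len)
    text = torch.stack([input_ids, attention_mask, segment_ids], dim=1)
    # audio/video: (features, lengths) tuples, as expected when need_data_aligned is False
    audio = (torch.randn(batch_size, seq_len, 5), torch.tensor([24, 20]))
    video = (torch.randn(batch_size, seq_len, 20), torch.tensor([24, 18]))

    res = model(text, audio, video)
    print({k: tuple(v.shape) for k, v in res.items()})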