-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathMILQT.py
64 lines (50 loc) · 2.84 KB
/
MILQT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
MILQT - "Multiple interaction learning with question-type prior knowledge for constraining answer search space in visual question answering."
Do, Tuong, Binh X. Nguyen, Huy Tran, Erman Tjiputra, Quang D. Tran, and Thanh-Toan Do.
Our arxiv link: https://arxiv.org/abs/2009.11118
This code is written by Vuong Pham and Tuong Do.
"""
import torch
import torch.nn as nn
class MILQT(nn.Module):
    """Multiple Interaction Learning with Question-Type prior (MILQT).

    Ensembles several VQA models: each model's features are combined with a
    question-type embedding, per-model predictions are blended by a learned
    linear layer, and a question-type mask is produced to constrain the
    answer search space.
    """

    def __init__(self, qt_model, models, question_type_mapping, combination_operator='add'):
        """
        Args:
            qt_model: question-type model exposing ``forward(questions)`` (returns
                an embedding, b x emb_dim) and ``classify(embedding)`` (returns
                question-type logits).
            models: iterable of VQA models, each exposing
                ``forward(visuals, boxes, questions)`` and ``classify(features)``.
            question_type_mapping: indexable mapping from a question-type index to
                a per-answer mask row (e.g. tensor of shape
                [num_question_types, num_answers]).
            combination_operator: 'add' or 'mul' — how model features are merged
                with the question-type embedding.

        Raises:
            ValueError: if ``combination_operator`` is not 'add' or 'mul'.
        """
        super(MILQT, self).__init__()
        # Fail fast on a bad operator instead of crashing later inside classify().
        if combination_operator not in ('add', 'mul'):
            raise ValueError(
                "combination_operator must be 'add' or 'mul', got %r" % (combination_operator,))
        self.question_type_model = qt_model
        # Register the sub-models so their parameters are tracked.
        self.models = nn.ModuleList(models)
        self.num_models = len(self.models)
        self.combination_operator = combination_operator
        # Learned, bias-free weights blending the per-model predictions.
        self.pred_combining_layer = nn.Linear(self.num_models, 1, bias=False)
        self.question_type_mapping = question_type_mapping
        # Results of the most recent forward pass, kept for external inspection.
        self.features = [0] * self.num_models
        self.features_combined = [0] * self.num_models
        self.preds = [0] * self.num_models
        self.preds_combined = [0] * self.num_models

    def forward(self, visuals, boxes, questions):
        """Run every sub-model and combine their predictions.

        Returns:
            tuple of (preds_combined, model_preds, model_preds_combined,
            question_type_preds, mask) where ``preds_combined`` is the list of
            per-model question-type-aware predictions, the two ``model_preds*``
            tensors are the blended predictions without/with the question-type
            prior, ``question_type_preds`` are the question-type logits, and
            ``mask`` (b x num_answers) marks the answers allowed for each
            example's predicted question type.
        """
        # Question-type embedding shared by all models, and its classification.
        question_emb = self.question_type_model(questions)  # b x emb_dim
        question_type_preds = self.question_type_model.classify(question_emb)

        # Per-model features and raw answer predictions (computed locally
        # instead of mutating module attributes inside forward).
        features = [model(visuals, boxes, questions) for model in self.models]
        preds = [model.classify(feat) for model, feat in zip(self.models, features)]

        # Inject the question-type prior into each model's features.
        if self.combination_operator == 'add':
            features_combined = [feat + question_emb for feat in features]
        else:  # 'mul' — the only other value, validated in __init__
            features_combined = [feat * question_emb for feat in features]
        preds_combined = [model.classify(feat)
                          for model, feat in zip(self.models, features_combined)]

        # Blend per-model predictions: (b, num_answers, num_models) -> (b, num_answers).
        model_preds = self.pred_combining_layer(torch.stack(preds, dim=2)).squeeze(2)
        model_preds_combined = self.pred_combining_layer(
            torch.stack(preds_combined, dim=2)).squeeze(2)

        # Predicted question type for each example in the batch.
        _, question_types = question_type_preds.max(1)

        # Build the answer-space mask row-by-row from the question-type prior.
        mask = torch.zeros(model_preds.shape, device=model_preds.device)
        for i, question_type in enumerate(question_types):
            # NOTE(review): indexed with a 0-dim tensor — works for tensor/list
            # mappings via __index__; confirm a dict is never passed here.
            mask[i] = self.question_type_mapping[question_type]

        # Preserve the original attribute side effects for external readers.
        self.features = features
        self.features_combined = features_combined
        self.preds = preds
        self.preds_combined = preds_combined

        return preds_combined, model_preds, model_preds_combined, question_type_preds, mask