scorer.py
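"""Score in-context example candidates with a vision-language model.

Summary inferred from the code below: every Accelerate process computes a
length-normalized negative log-likelihood ("score") for its shard of candidate
contexts and dumps the results to a temporary per-device file; the main process
then merges the shards, ranks each example's candidates by ascending loss, logs
the MRR of the original top candidate, and saves the final JSON.
"""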
import glob
import json
import logging
import os

import hydra
import torch
import tqdm
from accelerate import Accelerator
from transformers import set_seed

from inferencer import Inferencer
from open_flamingo.src.dataset_readers.scoring_dsr import ScorerDatasetReader
from open_flamingo.src.util.misc import save_json

logger = logging.getLogger(__name__)


class Scorer(Inferencer):
    def forward(self):
        # Only the main process draws a progress bar; every rank iterates.
        if self.accelerator.is_main_process:
            dataloader = tqdm.tqdm(self.dataloader)
        else:
            dataloader = self.dataloader

        res = []
        for entry in dataloader:
            metadata = entry.pop("metadata")
            with torch.no_grad():
                output = self.model(
                    vision_x=entry.images,
                    lang_x=entry.input_ids,
                    attention_mask=entry.attention_mask,
                )
            loss = self.nll_loss(entry=entry, output=output)
            # Attach each candidate's per-answer NLL to its metadata.
            for mdata, nll in zip(metadata, loss):
                mdata['score'] = nll
            res.extend(metadata)

        # Each process writes its own shard; write_results() merges them later.
        with open(f"{self.output_file}tmp_{self.accelerator.device}.bin", "w") as f:
            json.dump(res, f)

    def nll_loss(self, entry, output):
        # Standard causal-LM shift: the logit at position t predicts token t + 1.
        shift_logits = output.logits[..., :-1, :].contiguous()
        shift_labels = entry['input_ids'][..., 1:].contiguous()
        pad_token_id = self.dataset_reader.tokenizer.pad_token_id
        # entry['labels'] is already padded with pad_token_id; left-pad it to the
        # full (shifted) sequence length so prompt positions are masked out too.
        pad_mask = torch.nn.functional.pad(
            entry['labels'],
            (shift_labels.shape[-1] - entry['labels'].shape[-1], 0),
            value=pad_token_id,
        ).to(self.device)
        shift_labels.masked_fill_(pad_mask == pad_token_id, pad_token_id)
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=pad_token_id)
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        ).view(shift_labels.size())
        # Normalize by the number of answer tokens for a length-independent NLL.
        answer_lens = (entry['labels'] != pad_token_id).sum(-1)
        loss = loss.sum(-1) / answer_lens
        return loss.cpu().detach().numpy().tolist()
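
    # Toy illustration of the masking above (hypothetical values, not from the
    # repo): with input_ids = [q1, q2, q3, a1, a2] and labels = [a1, a2]
    # (answer tokens only), shift_labels = [q2, q3, a1, a2] has length 4, so
    # labels is left-padded to [PAD, PAD, a1, a2]; masked_fill_ then turns the
    # two prompt positions of shift_labels into PAD, CrossEntropyLoss ignores
    # them, and answer_lens = 2 normalizes the summed loss over a1 and a2.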

    def write_results(self):
        # Merge the per-device shards written by forward().
        data = []
        for i, path in enumerate(glob.glob(f"{self.output_file}tmp_*.bin")):
            with open(path) as f:
                one_device = json.load(f)
            logger.info(f"device: {i}, idx {[d['idx'] for d in one_device][:200]}...")
            data.extend(one_device)

        # Group the scored candidates by example uid.
        example_dict = {}
        uid_field = 'idx'
        for entry in data:
            ctxs = {"ctxs": entry.pop('ctxs'), "score": entry.pop("score")}
            if entry[uid_field] not in example_dict:
                entry['ctxs_candidates'] = [ctxs]
                example_dict[entry[uid_field]] = entry
            else:
                example_dict[entry[uid_field]]['ctxs_candidates'].append(ctxs)
        example_list = list(example_dict.values())

        mrr = 0
        num_candidates = len(example_list[0]['ctxs_candidates'])
        for entry in example_list:
            assert len(entry['ctxs_candidates']) == num_candidates, \
                f"{len(entry['ctxs_candidates'])} != {num_candidates}"
            # Rank candidates by ascending NLL (lower loss means a better candidate).
            sorted_tuple = sorted(enumerate(entry['ctxs_candidates']), key=lambda x: x[1]['score'])
            entry['ctxs_candidates'] = [t[1]['ctxs'] for t in sorted_tuple]
            entry['ctxs'] = entry['ctxs_candidates'][0]  # set the top-scored candidate as ctxs
            # Reciprocal rank of the original first candidate (index 0) in the new ranking.
            mrr += 1 / ([t[0] for t in sorted_tuple].index(0) + 1)
        logger.info(f"MRR: {mrr / len(example_list)}")

        save_json(self.output_file, example_list)
        # Remove the temporary per-device shards.
        for path in glob.glob(f"{self.output_file}tmp_*.bin"):
            os.remove(path)


@hydra.main(config_path="open_flamingo/config", config_name="scorer")
def main(cfg):
    logger.info(cfg)
    set_seed(43)
    accelerator = Accelerator()
    scorer = Scorer(cfg, accelerator)
    # Every process scores its shard; only the main process merges the results.
    scorer.forward()
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        scorer.write_results()


if __name__ == "__main__":
    main()
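
# A plausible multi-GPU launch, assuming a standard Accelerate setup (the exact
# Hydra overrides depend on open_flamingo/config/scorer.yaml, which is not shown
# here; `output_file=...` is a hypothetical override, not a documented flag):
#
#   accelerate launch --num_processes 4 scorer.py output_file=scored.json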