-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathensemble.py
48 lines (41 loc) · 1.52 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import json
import os
from collections import defaultdict
def main(k=3, data_dir="./outputs/test_dataset"):
    """Ensemble K n-best prediction files by probability-sum voting.

    Put the K n-best prediction files to be ensembled in ``data_dir``
    (``outputs/test_dataset`` by default), named::

        nbest_predictions_1.json
        nbest_predictions_2.json
        ...

    then run ``python ensemble.py``. The ensembled result is written to
    ``ensemble.json`` in the same directory.

    Each input file is a JSON object mapping an id to a list of candidate
    predictions; each candidate must have ``"text"`` and ``"probability"``
    keys. For every id, the probabilities of candidates with identical text
    are summed across all files, and the text with the largest total wins.
    On a tie, the text encountered first (earliest file, earliest position)
    wins. Ids whose candidate lists are empty in every file are omitted.

    Args:
        k: Number of files to read, ``nbest_predictions_1.json`` through
            ``nbest_predictions_{k}.json``. Defaults to 3.
        data_dir: Directory containing the input files and receiving
            ``ensemble.json``. Defaults to ``"./outputs/test_dataset"``.

    Returns:
        The ``{id: winning_text}`` dict that was written to ``ensemble.json``.

    Raises:
        FileNotFoundError: If any of the ``k`` input files is missing.
        KeyError: If a candidate lacks ``"text"`` or ``"probability"``.
    """
    nbests = []
    for i in range(1, k + 1):
        path = os.path.join(data_dir, f"nbest_predictions_{i}.json")
        # Explicit UTF-8: the answers contain non-ASCII (e.g. Korean) text, and
        # the platform default encoding (e.g. cp949 on Windows) may fail on it.
        with open(path, "r", encoding="utf-8") as file:
            nbests.append(json.load(file))

    # id -> {text: summed probability}. Dicts keep insertion order, so the
    # first-seen text is preferred on ties. The per-id dict is only created
    # once a candidate exists, which keeps all-empty ids out of the result.
    votes = defaultdict(dict)
    for nbest in nbests:
        for id_, pred_list in nbest.items():
            for pred in pred_list:
                tally = votes[id_]
                text = pred["text"]
                if text in tally:
                    tally[text] += pred["probability"]
                else:
                    tally[text] = pred["probability"]

    # max() returns the first maximal key, matching the stable descending sort.
    ensemble = {id_: max(tally, key=tally.get) for id_, tally in votes.items()}

    out_path = os.path.join(data_dir, "ensemble.json")
    with open(out_path, "w", encoding="utf-8") as writer:
        writer.write(json.dumps(ensemble, indent=4, ensure_ascii=False) + "\n")
    return ensemble
if __name__ == "__main__":
    # Script entry point: run the ensemble only when executed directly,
    # never as a side effect of importing this module.
    main()