-
Notifications
You must be signed in to change notification settings - Fork 48
/
evaluate.py
141 lines (127 loc) · 5.8 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import argparse
import re
import pandas as pd
from tqdm import tqdm
from ckbqa.utils.logger import logging_config
from config import ResultSaver
def train_data():
    """Evaluate saved training-set predictions for CEG and for the final answer.

    Reads the first CSV path in ``ResultSaver(find_exist_path=True).train_result_csv``
    and, for every row, scores with ``get_metrics``:

    * CEG (Candidate Entity Generation): ``standard_subject_entities`` vs.
      ``candidate_entities``;
    * answers: ``standard_answer_entities`` vs. ``result_entities``.

    The per-row precision / recall / F1 values are averaged (macro average) and
    printed. Nothing is returned. If the CSV has no rows, a message is printed
    instead of failing with a division by zero.
    """
    logging_config('train_evaluate.log', stream_log=True)
    from ckbqa.models.evaluation_matrics import get_metrics

    # Captures the text between a "<" or '"' and the next ">" or '"'.
    partten = re.compile(r'["<](.*?)[>"]')

    _paths = ResultSaver(find_exist_path=True).train_result_csv
    print(_paths)
    train_df = pd.read_csv(_paths[0])
    if train_df.empty:
        # Guard: the averages below would otherwise raise ZeroDivisionError.
        print(f"no rows to evaluate in {_paths[0]}")
        return

    ceg_precisions, ceg_recalls, ceg_f1_scores = [], [], []
    answer_precisions, answer_recalls, answer_f1_scores = [], [], []
    for _, row in tqdm(train_df.iterrows(), total=train_df.shape[0], desc='evaluate '):
        # Work around an earlier bug: the <> around entities were stripped when
        # questions were parsed but kept at prediction time, so compare only the
        # text inside the delimiters, not the <> / "" themselves.
        subject_entities = partten.findall(row['standard_subject_entities'])
        if not subject_entities:
            # NOTE(review): eval() executes arbitrary code. It is only tolerable
            # because this CSV is produced by the project itself; if the cells are
            # always plain Python literals, ast.literal_eval would be the safe choice.
            subject_entities = eval(row['standard_subject_entities'])

        # CEG: Candidate Entity Generation
        candidate_entities = (eval(row['candidate_entities'])
                              + partten.findall(row['candidate_entities']))
        precision, recall, f1 = get_metrics(subject_entities, candidate_entities)
        ceg_precisions.append(precision)
        ceg_recalls.append(recall)
        ceg_f1_scores.append(f1)

        # Final answer
        standard_entities = eval(row['standard_answer_entities'])
        result_entities = eval(row['result_entities'])
        precision, recall, f1 = get_metrics(standard_entities, result_entities)
        answer_precisions.append(precision)
        answer_recalls.append(recall)
        answer_f1_scores.append(f1)

    ave_ceg_precision = sum(ceg_precisions) / len(ceg_precisions)
    ave_ceg_recall = sum(ceg_recalls) / len(ceg_recalls)
    ave_ceg_f1_score = sum(ceg_f1_scores) / len(ceg_f1_scores)
    print(f"ave_ceg_precision: {ave_ceg_precision:.3f}, "
          f"ave_ceg_recall: {ave_ceg_recall:.3f}, "
          f"ave_ceg_f1_score:{ave_ceg_f1_score:.3f}")

    ave_answer_precision = sum(answer_precisions) / len(answer_precisions)
    ave_answer_recall = sum(answer_recalls) / len(answer_recalls)
    ave_answer_f1_score = sum(answer_f1_scores) / len(answer_f1_scores)
    print(f"ave_result_precision: {ave_answer_precision:.3f}, "
          f"ave_result_recall: {ave_answer_recall:.3f}, "
          f"ave_result_f1_score:{ave_answer_f1_score:.3f}")
def ceg():
    """Evaluate the Candidate Entity Generation (CEG) step on the raw training data.

    For every (question, sparql, answer) triple yielded by ``load_data``:

    * the gold entities/attributes are extracted from the SPARQL query with
      ``entity_pattern`` and ``attr_pattern``;
    * they are compared, via ``get_metrics``, against the candidates that
      ``CEG.get_ent2mention`` returns for the question text.

    Questions with zero recall are printed for inspection. Per-question rows
    (question, gold entities, candidate entities) are written to ``./ceg.csv``,
    and the averaged precision / recall / F1 are printed. Nothing is returned.
    If no questions were evaluated, a message is printed instead of failing with
    a division by zero.
    """
    logging_config('train_evaluate.log', stream_log=True)
    from ckbqa.models.evaluation_matrics import get_metrics
    from ckbqa.qa.el import CEG
    from ckbqa.dataset.data_prepare import load_data, question_patten, entity_pattern, attr_pattern

    # Named ceg_model rather than ceg so the local does not shadow this function.
    ceg_model = CEG()  # Candidate Entity Generation
    ceg_precisions, ceg_recalls, ceg_f1_scores = [], [], []
    ceg_csv = "./ceg.csv"
    data = []
    for q, sparql, _ in load_data(tqdm_prefix='ceg evaluate '):
        q_entities = entity_pattern.findall(sparql) + attr_pattern.findall(sparql)
        # Work around an earlier bug: the <> around entities were stripped when
        # questions were parsed but kept at prediction time, so match only the
        # text, not the <> / "" delimiters.
        q_text = ''.join(question_patten.findall(q))
        ent2mention = ceg_model.get_ent2mention(q_text)

        # CEG: Candidate Entity Generation
        precision, recall, f1 = get_metrics(q_entities, ent2mention)
        ceg_precisions.append(precision)
        ceg_recalls.append(recall)
        ceg_f1_scores.append(f1)

        data.append([q, q_entities, list(ent2mention.keys())])
        if recall == 0:
            # Surface complete misses for manual inspection.
            print(f"question: {q}\n"
                  f"subject_entities: {q_entities}, candidate_entities: {ent2mention}\n"
                  f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n")
            print('\n\n')

    pd.DataFrame(data, columns=['question', 'q_entities', 'ceg']).to_csv(
        ceg_csv, index=False, encoding='utf_8_sig')

    if not ceg_precisions:
        # Guard: the averages below would otherwise raise ZeroDivisionError.
        print("no questions to evaluate")
        return

    ave_precision = sum(ceg_precisions) / len(ceg_precisions)
    ave_recall = sum(ceg_recalls) / len(ceg_recalls)
    ave_f1_score = sum(ceg_f1_scores) / len(ceg_f1_scores)
    print(f"ave_precision: {ave_precision:.3f}, "
          f"ave_recall: {ave_recall:.3f}, "
          f"ave_f1_score:{ave_f1_score:.3f}")
def main():
parser = argparse.ArgumentParser(description="基础,通用parser")
# logging config 日志配置
parser.add_argument('--stream_log', action="store_true", help="是否将日志信息输出到标准输出") # log print到屏幕
#
group = parser.add_mutually_exclusive_group(required=True) # 一组互斥参数,且至少需要互斥参数中的一个
group.add_argument('--ceg', action="store_true", help="ceg Candidate Entity Generation评价")
group.add_argument('--train_data', action="store_true", help="train_answer_data评价")
# parse args
args = parser.parse_args()
#
# from ckbqa.utils.tools import ProcessManager
# ProcessManager().run()
if args.ceg:
ceg()
elif args.train_data:
train_data()
elif args.task:
task()
def task():
    """Configure logging to ``ceg.log`` and then run the CEG evaluation.

    NOTE(review): ``ceg()`` starts with its own
    ``logging_config('train_evaluate.log', stream_log=True)`` call, so whether
    ``ceg.log`` still receives output depends on how ``logging_config`` handles
    being called twice — confirm.
    """
    log_file = 'ceg.log'
    logging_config(log_file, stream_log=True)
    return ceg()
if __name__ == '__main__':
    # Example usage (run in the background, output redirected to a file):
    #   nohup python evaluate.py --train_data &>train_data.out &
    #   nohup python evaluate.py --ceg &>ceg.out &
    # To monitor memory usage in real time, uncomment:
    # from ckbqa.utils.tools import ProcessManager
    # ProcessManager().run()
    main()