atomic_fact_generation_for_eval.py
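"""Generate atomic facts for evaluation.

For each outlet listed in the dataset split, this script fills in a set of
fill-in-the-blank question templates from the outlet's reference text using
GPT-3.5, verifies each generated fact against that text with a majority-vote
entailment check, and writes the surviving facts to the output folder as
tab-separated "fact<TAB>label" lines (label is "entailment" or
"contradiction").
"""
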
import argparse
import json
import os
import sys
import time
from collections import Counter

import tqdm
from openai import OpenAI

def most_frequent_element(input_list):
    return Counter(input_list).most_common(1)[0][0] if input_list else None
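# For example, most_frequent_element(["entailment", "neutral", "entailment"])
# returns "entailment". Ties are broken by first occurrence (Counter preserves
# insertion order), and an empty list returns None.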

# Read the config file to get the OpenAI API key.
with open("config.json", "r") as f:
    config = json.load(f)
openai_key = config["openai_key"]
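# config.json is expected to provide the API key; a minimal example
# (only the "openai_key" field is read here):
#
#     {"openai_key": "sk-..."}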
client = OpenAI(api_key=openai_key)

def send_message(message):
    # Send a list of chat messages to the OpenAI API and return the reply text.
    model = "gpt-3.5-turbo"
    response = client.chat.completions.create(
        model=model,
        messages=message,
    )
    return response.choices[0].message.content

def check_implication(claim, text, passes=4):
    # Query the checker several times and take a majority vote to reduce
    # variance in the model's judgments.
    votes = [gpt_check_implication(claim, text) for _ in range(passes)]
    return most_frequent_element(votes)
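# Example with hypothetical votes: for passes=4 and votes
# ["entailment", "entailment", "neutral", "contradiction"], the majority label
# "entailment" is returned; a 2-2 split resolves to whichever label was
# seen first.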

def gpt_check_implication(claim, text):
    # TODO: replace with a dedicated NLI model.
    system_message = "You are FactCheckGPT, a world-class tool used by journalists to discover problems in their writings. Users give you text, and check whether facts are true given the text. You ALWAYS answer either TRUE, FALSE, or NOT ENOUGH EVIDENCE."
    prompt = "You will be given a snippet written as part of a source criticism exercise, and a claim. Your task is to determine whether the claim is true based ONLY on the text. Do NOT use any other knowledge source.\n\n"
    prompt += "The claim is: \"" + claim + "\".\n"
    prompt += "The text follows below:\n\"" + text + "\".\n\n"
    prompt += claim + " Thinking step by step, answer either TRUE, FALSE, or NOT ENOUGH EVIDENCE, capitalizing all letters. Explain your reasoning FIRST, and after that output either TRUE, FALSE, or NOT ENOUGH EVIDENCE."
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]
    # Retry with exponential backoff (capped at 64 seconds) until the API call succeeds.
    retries = 0
    result = None
    while result is None:
        try:
            result = send_message(messages)
        except Exception:
            print(f"Failed to send message. Retrying (attempt {retries + 1})...", file=sys.stderr)
            delay = min(2 ** retries, 64)
            time.sleep(delay)
            retries += 1
    # Map the model's verdict onto NLI-style labels. TRUE/FALSE are matched as
    # substrings, so NOT ENOUGH EVIDENCE (and anything else) falls through to neutral.
    if "TRUE" in result:
        return "entailment"
    elif "FALSE" in result:
        return "contradiction"
    else:
        return "neutral"

def generate_atomic_fact(question, text):
    # Ask the model to fill in the blank (_) in a question template using
    # only information from the given text.
    system_message = "You are InfoHuntGPT, a world-class tool used by journalists to quickly extract claims from text."
    prompt = "You will be given a snippet written as part of a source criticism exercise, and a fill-in-the-blank question (blanks represented by _). Your task is to fill in the blanks in the sentence, adding no additional information or wording. JUST replace the _ character. No yapping.\n\n"
    prompt += "The question is:\n" + question + "\n\n"
    prompt += "The text follows below:\n\"" + text + "\".\n\n"
    prompt += "Fill in the blanks in the question, adding no additional information or wording. JUST replace the _ character, and output ONLY the question with the blank filled in. No yapping."
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]
    # Retry with exponential backoff (capped at 64 seconds) until the API call succeeds.
    retries = 0
    result = None
    while result is None:
        try:
            result = send_message(messages)
        except Exception:
            print(f"Failed to send message. Retrying (attempt {retries + 1})...", file=sys.stderr)
            delay = min(2 ** retries, 64)
            time.sleep(delay)
            retries += 1
    return result
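# Example with a hypothetical template: given the question
# "Example Outlet was founded in _." and a text that states the founding year,
# the model should return something like "Example Outlet was founded in 1998."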

def process_files(questions_file, folder_path, dataset_file, output_folder, start_at=0):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    # The dataset split is a TSV whose first column holds the reference filenames.
    with open(dataset_file, 'r') as f:
        dataset_lines = [line.split("\t")[0].strip() for line in f.readlines()]
    # Load the fill-in-the-blank question templates from the JSON query file.
    with open(questions_file, 'r') as file:
        data = json.load(file)
    questions = [item['statement'].strip() for item in data]
    # Iterate through each reference file listed in the dataset split.
    print("Starting at " + str(start_at))
    for i, filename in enumerate(tqdm.tqdm(dataset_lines)):
        if i < start_at:
            continue
        file_path = os.path.join(folder_path, filename)
        # Read content from the text file.
        with open(file_path, 'r') as file:
            text_content = file.read()
        # Keep only the part of the page from the "History" section onwards,
        # and strip site boilerplate.
        parts = text_content.split("\nHistory\n")
        text_to_check = "History\n" + "\nHistory\n".join(parts[1:]) if len(parts) > 1 else ""
        text_to_check = text_to_check.replace("[Media Bias Fact Check]()", "").split("Last Updated")[0]
        text_to_check = text_to_check.replace("Mediabiasfactcheck.com", "")
        # Collect the verified atomic facts for this outlet.
        atomic_facts = []
        outlet = filename[:-4]  # strip the file extension to recover the outlet name
        # Iterate through each question template.
        for question in questions:
            # Substitute the outlet name for the X placeholder.
            localized_question = question.replace("X", outlet)
            if "_" not in question:
                result = localized_question
            else:
                result = generate_atomic_fact(localized_question, text_to_check).replace("_", "")
            # Basic sanity check: no newlines, no tabs, and at least as many
            # words as the original template.
            result = result.split("\n")[-1].replace("\t", " ")
            if len(result.split()) < len(question.split()):
                continue
            # Keep only facts the checker labels as entailed or contradicted.
            entail_pos = check_implication(result, text_to_check)
            if entail_pos != "neutral":
                atomic_facts.append("\t".join([result, entail_pos]))
        # Write the facts for the current file, one "fact<TAB>label" line each.
        output_file = os.path.join(output_folder, filename)
        with open(output_file, 'w') as f:
            f.write("\n".join(atomic_facts))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate atomic facts for evaluation')
    parser.add_argument('--start_at', type=int, default=0)
    parser.add_argument('--dataset_file', type=str, default="data/splits/dev.tsv")
    parser.add_argument('--query_file', type=str, default="data/queries.json")
    parser.add_argument('--fact_folder', type=str, default="data/splits/dev_facts")
    parser.add_argument('--reference_folder', type=str, default="data/mbcs")
    args = parser.parse_args()

    process_files(
        args.query_file,
        args.reference_folder,
        args.dataset_file,
        args.fact_folder,
        start_at=args.start_at,
    )
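
# Example invocation (assuming the default data layout implied by the
# argument defaults above):
#
#     python atomic_fact_generation_for_eval.py \
#         --dataset_file data/splits/dev.tsv \
#         --query_file data/queries.json \
#         --reference_folder data/mbcs \
#         --fact_folder data/splits/dev_facts \
#         --start_at 0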