-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathautohall_gemini.py
99 lines (76 loc) · 2.78 KB
/
autohall_gemini.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import google.generativeai as genai
# from google.colab import userdata
from IPython.display import display
from IPython.display import Markdown
from PIL import Image
from tqdm import tqdm
import json
import os
import time
GOOGLE_API_KEY="Your Key"
genai.configure(api_key=GOOGLE_API_KEY)
for m in genai.list_models():
print(m.name)
print(m.supported_generation_methods)
model = genai.GenerativeModel('gemini-1.5-flash-latest')
root = "/home/rayguan/Desktop/trustAGI/TrustAGI-anno-platform/data"
json_in = "/home/rayguan/Desktop/trustAGI/TrustAGI-anno-platform/data/autohallusion_data.json"
json_out = "/home/rayguan/Desktop/trustAGI/TrustAGI-anno-platform/data/autohallusion_data_gemini_res.json"
json_tmp = "/home/rayguan/Desktop/trustAGI/TrustAGI-anno-platform/data/autohallusion_data_gemini_res_resume.json"
if os.path.exists(json_tmp):
json_in = json_tmp
with open(json_in, 'r') as f:
data = json.load(f)
# print(len(data))
count = 0
attempt = 10
for i, element in enumerate(tqdm(data)):
# print(i)
if "res" in element and element['res'] != 'null':
continue
attm_count = 0
while attm_count < attempt:
try:
question = data[i]['prompt']
# print(question)
if "image_urls" not in element:
response = model.generate_content([question])
result = response.text
element['res'] = result
else:
img_path = os.path.join(root, element["image_urls"][0])
# img_path = "/home/rayguan/Desktop/trustAGI/TrustAGI-anno-platform/data/" + element['image_urls'][0]
raw_image = Image.open(img_path)
response = model.generate_content([question, raw_image])
result = response.text
element['res'] = result
# print(result)
count += 1
break
except:
print("Timeout, Retrying...")
print(response.prompt_feedback)
time.sleep(5)
attm_count += 1
# if attm_count >= attempt:
element['res'] = 'null'
# try:
# question = data[i]['prompt']
# response = model.generate_content([question])
# result = response.text
# element['res'] = result
# print(result)
# except:
# element['res'] = 'null'
# print('null')
# count += 1
if i % 10 == 0:
save_dict_tmp = json.dumps(data, indent=4)
with open(json_tmp, "w") as out_f:
out_f.write(save_dict_tmp)
print("Progress saved.")
print(len(data) - count)
# print(i)
output_json = json.dumps(data, indent=4)
with open(json_out, 'w') as f:
f.write(output_json)