-
Notifications
You must be signed in to change notification settings - Fork 6
/
data_load.py
95 lines (78 loc) · 3.09 KB
/
data_load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import json
import pandas as pd
from io import BytesIO
import base64
from PIL import Image
import argparse
from data_visualization import (
element_visual,
elements_visual,
actions_visual
)
def read_json(path):
    """Load and return the JSON document stored at ``path``.

    Args:
        path: Location of a JSON file; it is opened as UTF-8 text.

    Returns:
        The deserialized top-level JSON value, as produced by ``json.load``.
    """
    # json.load parses directly from the file object, so the manual
    # f.read() + json.loads round trip is unnecessary.
    with open(path, "r", encoding="utf8") as f:
        return json.load(f)
def read_parquet(path, columns=None):
    """Read a parquet file into a pandas DataFrame.

    Args:
        path: Parquet file to read.
        columns: Optional list of column names to load. ``None`` (the
            default) loads every column, which is what this function always
            did before the parameter existed.

    Returns:
        The DataFrame returned by ``pandas.read_parquet``.
    """
    return pd.read_parquet(path, columns=columns)
def decode_base64_to_image(base64_string):
    """Turn a base64-encoded image payload into an RGB PIL image.

    Args:
        base64_string: Base64 text (or bytes) holding an encoded image file.

    Returns:
        The image opened with ``Image.open`` and converted to mode "RGB".
    """
    raw_bytes = base64.b64decode(base64_string)
    picture = Image.open(BytesIO(raw_bytes))
    return picture.convert("RGB")
def read_image_from_qarquet(cur_df, image_id, b64decode=True):
    """Fetch one image from the images DataFrame and return it as RGB.

    Args:
        cur_df: DataFrame indexed by image id with a "base64" column.
        image_id: Index label of the row to read.
        b64decode: When true, the "base64" cell is base64-decoded via
            ``decode_base64_to_image``; when false, the cell is passed to
            ``BytesIO`` as-is and opened directly as an image file.

    Returns:
        A PIL image converted to mode "RGB".
    """
    payload = cur_df.loc[image_id]["base64"]
    if not b64decode:
        # The cell is used without decoding: wrap it in a buffer and open it.
        return Image.open(BytesIO(payload)).convert("RGB")
    return decode_base64_to_image(payload)
if __name__ == "__main__":
    # Datasets this script knows about and the files that belong to them
    # (pass a matching pair through --data_path / --img_path):
    # - guienv
    #     - ocr_grounding_test_data.json
    #     - ocr_grounding_test_images.parquet
    # - guiact
    #     - web-single_test_data.json
    #     - web-single_test_images.parquet
    #     - web-multi_test_data.json
    #     - web-multi_test_images.parquet
    #     - smartphone_test_data.json
    #     - smartphone_test_images.parquet
    # - guichat
    #     - guichat_data.json
    #     - guichat_images.parquet
    parser = argparse.ArgumentParser("")
    parser.add_argument("--data_path", default="./data/guichat_data.json")
    parser.add_argument("--img_path", default="./data/guichat_images.parquet")
    parser.add_argument("--dataset", default="guichat")
    args = parser.parse_args()

    data = read_json(args.data_path)
    cur_df = read_parquet(args.img_path)

    if args.dataset == "guienv":
        for sample in data:
            image_id, question, answer = sample["image_id"], sample["question"], sample["answer"]
            image = read_image_from_qarquet(cur_df, image_id)
            task_type = sample["task_type"]
            if task_type == "text2bbox":
                img_with_box = element_visual(answer, image.copy(), question)
            elif task_type == "bbox2text":
                img_with_box = element_visual(question, image.copy(), answer)
            else:
                print("unsupported task type.")
                # Skip the sample: there is no rendering for it, and falling
                # through would either raise NameError (first sample) or
                # re-save the previous sample's image.
                continue
            img_with_box.save("1.png")
            # Pause after each sample; the next iteration overwrites 1.png.
            breakpoint()
    elif args.dataset == "guiact":
        for sample in data:
            image_id, question, actions = sample["image_id"], sample["question"], sample["actions_label"]
            image = read_image_from_qarquet(cur_df, image_id)
            # 1.png: actions overlay, 2.png: elements overlay, 3.png: raw image.
            img_with_actions = actions_visual(actions, image.copy(), question)
            img_with_actions.save("1.png")
            img_with_elements = elements_visual(cur_df.loc[image_id]["elements"], image.copy())
            img_with_elements.save("2.png")
            image.save("3.png")
            breakpoint()
    elif args.dataset == "guichat":
        for sample in data:
            image_id = sample["image_id"]
            # For guichat the "base64" cell is opened without base64-decoding.
            image = read_image_from_qarquet(cur_df, image_id, b64decode=False)
            image.save("1.png")
            breakpoint()
    else:
        print("unsupported dataset.")