-
Notifications
You must be signed in to change notification settings - Fork 0
/
render_all.py
44 lines (39 loc) · 1.36 KB
/
render_all.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import utils
import json
import os
import argparse
import cairosvg
import PIL.Image
import io
import tqdm
def default_argument_parser():
    """Build the command-line parser for this script.

    Options:
        --format        which kind of vector-graphics source to render; one of
                        ``svg``, ``tikz`` or ``graphviz`` (required).
        --dataset-path  path to the JSON dataset file to render (required).

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    parser = argparse.ArgumentParser(
        description="render the original vector graphics in the dataset")
    # No ``default=`` here: the option is required, so a default could never
    # take effect (and an empty string is not one of the allowed choices).
    parser.add_argument(
        "--format",
        choices=["svg", "tikz", "graphviz"],
        required=True,
        help="the format of the vector graphics",
    )
    parser.add_argument(
        "--dataset-path",
        required=True,
        help="path to the JSON dataset file",
    )
    return parser
def _render_sample(fmt, code):
    """Render one vector-graphics source string with the renderer for *fmt*.

    Args:
        fmt: one of ``"svg"``, ``"tikz"`` or ``"graphviz"``.
        code: the source text of the graphic.

    Returns:
        A PIL image for ``svg``; for ``tikz``/``graphviz``, whatever
        ``utils.render_tikz`` / ``utils.render_graphviz`` return. The caller
        treats ``None`` as "could not render" (presumably those helpers return
        ``None`` on failure -- confirm in ``utils``).

    Raises:
        ValueError: if *fmt* is not a supported format.
    """
    if fmt == "svg":
        png_bytes = cairosvg.svg2png(code, background_color="white")
        return PIL.Image.open(io.BytesIO(png_bytes))
    if fmt == "tikz":
        return utils.render_tikz(code)
    if fmt == "graphviz":
        return utils.render_graphviz(code)
    # Unreachable through the CLI (argparse restricts --format), but raising a
    # plain string is itself a TypeError in Python 3, so raise a real exception.
    raise ValueError("Unknown format: %r" % (fmt,))


def main(argv=None):
    """Render every sample of a JSON dataset to ``pngs/<format>/<index>.png``.

    The dataset is read from ``--dataset-path`` and must be a JSON list whose
    items have a ``"code"`` entry holding the graphic's source. Samples whose
    output PNG already exists are skipped, so an interrupted run can be resumed.
    Samples whose renderer returns ``None`` produce no file.

    Args:
        argv: optional argument list to parse instead of ``sys.argv[1:]``.
    """
    args = default_argument_parser().parse_args(argv)
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform default.
    with open(args.dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    out_dir = os.path.join("pngs", args.format)
    # img.save() does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    for idx, sample in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
        out_file_path = os.path.join(out_dir, "%d.png" % idx)
        if os.path.exists(out_file_path):
            continue  # rendered by a previous run
        img = _render_sample(args.format, sample["code"])
        if img is None:
            continue
        img.save(out_file_path)
# Only render when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()