-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdump_encoded_imgs.py
63 lines (53 loc) · 2.14 KB
/
dump_encoded_imgs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Load a dataset from parquet files and decode it to PNGs."""
import argparse
import numpy as np
import PIL.Image
import torch
import tqdm
from datasets import Dataset
from itertools import islice
from omegaconf import OmegaConf
from pathlib import Path
from tqdm import tqdm
from txt2img_unsupervised.ldm_autoencoder import LDMAutoencoder
from txt2img_unsupervised.load_pq_dir import load_pq_dir
import txt2img_unsupervised.ldm_autoencoder as ldm_autoencoder
# Command-line interface. Validation uses parser.error (prints usage, exits
# with status 2) instead of assert, because asserts are stripped under -O.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--pq-dir", type=Path, required=True, help="Directory of parquet files to load."
)
parser.add_argument(
    "--output-dir", type=Path, required=True, help="Directory to write PNGs into."
)
parser.add_argument(
    "--ae-cfg", type=Path, required=True, help="Autoencoder config file."
)
parser.add_argument(
    "--ae-ckpt", type=Path, required=True, help="Autoencoder checkpoint file."
)
parser.add_argument(
    "--batch-size", type=int, default=32, help="Images decoded per batch."
)
parser.add_argument(
    "--res",
    type=int,
    required=True,
    help="Output image resolution; must be a positive multiple of 4.",
)
parser.add_argument(
    "-n", type=int, required=True, help="Number of images to decode."
)
args = parser.parse_args()

if args.res <= 0 or args.res % 4 != 0:
    parser.error(f"--res must be a positive multiple of 4, got {args.res}")
if args.n <= 0:
    parser.error(f"-n must be positive, got {args.n}")
if args.batch_size <= 0:
    # Used as a divisor when counting batches, so it must be non-zero.
    parser.error(f"--batch-size must be positive, got {args.batch_size}")

# Side length of the latent grid handed to the decoder; presumably the
# autoencoder downsamples by a factor of 4 per side — confirm against ae_cfg.
res_tokens = args.res // 4
print(f"Loading autoencoder from {args.ae_ckpt}")
# Only the config's model.params section is needed to build the model.
ae_cfg = OmegaConf.load(args.ae_cfg)["model"]["params"]
ae_mdl = LDMAutoencoder(ae_cfg)
# Read the checkpoint onto the CPU, then hand it to the model's converter,
# which presumably maps the torch weights into the params decode_jv expects.
# NOTE: torch.load unpickles the file — only pass trusted checkpoints.
torch_ckpt = torch.load(args.ae_ckpt, map_location="cpu")
ae_params = ae_mdl.params_from_torch(torch_ckpt, ae_cfg)
del torch_ckpt  # the raw torch state is not needed after conversion
# Load the encoded dataset, shuffle it, and build an iterator over exactly
# the batches needed to decode n images (or every image, if fewer exist).
print(f"Loading dataset from {args.pq_dir}")
dset = load_pq_dir(args.pq_dir)
dset.set_format("numpy")
dset = dset.shuffle()
print(f"Found {len(dset)} images")
print(f"Decoding dataset to {args.output_dir}")
args.output_dir.mkdir(exist_ok=True, parents=True)
print(f"Batch size {args.batch_size}")

n_to_decode = min(args.n, len(dset))
if n_to_decode < args.n:
    print(
        f"Only {len(dset)} images available; decoding all of them "
        f"instead of {args.n}"
    )
# Restrict to the first n shuffled rows so the final (partial) batch stops at
# exactly n images rather than decoding up to a whole extra batch.
dset = dset.select(range(n_to_decode))

# Ceiling division: number of batches (last one possibly partial).
batches_to_decode = -(-n_to_decode // args.batch_size)
print(f"Decoding {batches_to_decode} batches")
dset_iter = dset.iter(batch_size=args.batch_size, drop_last_batch=False)
# Decode each batch of latents with the autoencoder and write one PNG per
# image, named after the row's "name" column, advancing the progress bar
# by the number of images decoded in that batch.
with tqdm(total=args.n, unit="img") as progress:
    for batch in dset_iter:
        latent_hw = (res_tokens, res_tokens)
        decoded = ldm_autoencoder.decode_jv(
            ae_mdl, ae_params, latent_hw, batch["encoded_img"]
        )
        for pixels, img_name in zip(decoded, batch["name"]):
            # np.array converts the decoder output to a host array for PIL.
            png = PIL.Image.fromarray(np.array(pixels))
            png.save(args.output_dir / f"{img_name}.png")
        progress.update(len(decoded))