-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_attr.py
executable file
·80 lines (63 loc) · 2.28 KB
/
get_attr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 17 21:33:42 2020
@author: wuzongze
Modified on Wed Jul 20 11:29 2022
@author Susmit-A
"""
import os
import dnnlib.tflib as tflib
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
import argparse
import glob
from tqdm import tqdm
import pickle as pkl
def convert_images_from_uint8(images, drange=(-1, 1), nhwc_to_nchw=False):
    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.

    Can be used as an input transformation for Network.run().

    Args:
        images: Array of integer pixel values in [0, 255]. When
            ``nhwc_to_nchw`` is True it must have at least 4 axes with the
            channel axis at index 3 (e.g. ``(N, H, W, C)``).
        drange: Indexable ``(low, high)`` pair; pixel value 0 maps to
            ``drange[0]`` and 255 maps to ``drange[1]``.
        nhwc_to_nchw: If True, move axis 3 to position 1 (NHWC -> NCHW).

    Returns:
        ``np.float32`` array with the same shape as ``images``, or its
        NCHW permutation when ``nhwc_to_nchw`` is True.
    """
    # Cast up front so the arithmetic runs in float32 (the documented output
    # dtype) instead of numpy's default float64 for uint8 / int.
    images = np.asarray(images).astype(np.float32)
    if nhwc_to_nchw:
        # Same permutation as np.rollaxis(images, 3, 1).
        images = np.moveaxis(images, 3, 1)
    # Without nhwc_to_nchw the input layout is kept unchanged.
    low, high = drange[0], drange[1]
    return images / 255 * (high - low) + low
#%%
def _load_classifiers(classifier_dir, names):
    """Unpickle each classifier file listed in *names* from *classifier_dir*.

    Returns a list of ``(file_name, network)`` pairs in the order of *names*.
    """
    classifiers = []
    for name in tqdm(names):
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only point --classifier_path at trusted classifier pickles.
        with open(os.path.join(classifier_dir, name), 'rb') as f:
            classifiers.append((name, pkl.load(f)))
    return classifiers


def _load_tile_batch(file, num_tiles):
    """Read the image at *file* and split it into a batch of equal-width tiles.

    The decoded pixel array is cut into *num_tiles* pieces along axis 1,
    the pieces are stacked on a new leading batch axis, and the batch is
    rescaled to [-1, 1] with axis 3 moved to axis 1 via
    ``convert_images_from_uint8``. ``np.split`` raises ValueError if the size
    of axis 1 is not divisible by *num_tiles*.
    """
    # The context manager guarantees the file is closed even if decoding fails.
    with Image.open(file) as im:
        img = np.array(im)
    tiles = np.stack(np.split(img, num_tiles, axis=1), axis=0)
    return convert_images_from_uint8(tiles, drange=[-1, 1], nhwc_to_nchw=True)


def main():
    """Score every PNG in --img_path with every classifier and pickle the scores.

    The pickled object is ``{image_file_name: {classifier_file_name: scores}}``,
    where ``scores`` is the flattened output of ``model.run`` on that image's
    tile batch.
    """
    parser = argparse.ArgumentParser(description='predict pose of object')
    parser.add_argument('--img_path', type=str, required=True,
                        help='path to image folder')
    parser.add_argument('--save_path', type=str, required=True,
                        help='path to save attribute file')
    parser.add_argument('--classifier_path', default='./attr_models', type=str,
                        help='path to a folder of classifers')
    parser.add_argument('--num_tiles', default=11, type=int,
                        help='number of equal-width tiles each image is split into')
    opt = parser.parse_args()

    # Sorted so progress and output order are reproducible across runs.
    imgs = sorted(glob.glob(os.path.join(opt.img_path, '*.png')))
    names = sorted(name for name in os.listdir(opt.classifier_path)
                   if 'celebahq-classifier' in name)

    # The TF session must exist before the classifier pickles are loaded.
    tflib.init_tf()
    print("Initializing models")
    classifiers = _load_classifiers(opt.classifier_path, names)

    results = {}
    for file in tqdm(imgs):
        # Decode and preprocess once per image: the batch is identical for
        # every classifier, so there is no need to redo it per model.
        batch = _load_tile_batch(file, opt.num_tiles)
        # os.path.basename also handles Windows separators, unlike split('/').
        results[os.path.basename(file)] = {
            name: model.run(batch, None).reshape(-1)
            for name, model in classifiers
        }

    with open(opt.save_path, 'wb') as f:
        pkl.dump(results, f)


if __name__ == "__main__":
    main()