-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelper.py
82 lines (69 loc) · 2.69 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import zipfile
import requests
import math
import numpy as np
from PIL import Image
from tqdm import tqdm
# Debug switch read by download_celeb_a(): the downloaded archive is only
# unpacked when this is False.
DEBUG = False
def get_confirm_token(response):
    """
    Look up the download-confirmation token in a response's cookies.

    Returns the value of the first cookie whose name starts with
    'download_warning', or None when no such cookie is present.
    """
    matching_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(matching_values, None)
def save_response_content(response, destination, chunk_size=32*1024):
    """
    Stream the body of a response into a file, showing a byte progress bar.

    :param response: Response object exposing `headers` and
        `iter_content(chunk_size)` (as returned by a streamed requests call)
    :param destination: Path of the file to write (opened in binary mode)
    :param chunk_size: Number of bytes requested per chunk
    """
    total_size = int(response.headers.get('content-length', 0))

    with open(destination, "wb") as f:
        # The bar is advanced manually by the number of bytes written.
        # Wrapping the chunk iterator would count one step per chunk while
        # the total is expressed in bytes, so the bar would never fill.
        # A missing content-length yields a bar without a known total.
        with tqdm(total=total_size or None, unit='B', unit_scale=True,
                  desc=destination) as progress:
            for chunk in response.iter_content(chunk_size):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    progress.update(len(chunk))
def download_file_from_google_drive(id, destination):
    """
    Download a file from Google Drive and save it to `destination`.

    A first request is made with the file id only. If the response carries a
    confirmation token (see get_confirm_token), the request is repeated with
    that token and the second response is the one saved.

    :param id: Google Drive file id
    :param destination: Path of the file to write
    :raises requests.HTTPError: If the final response has an error status
    """
    print("Downloading into ./data/... Please wait.")
    URL = "https://docs.google.com/uc?export=download"

    # The context manager closes the session even when the download fails
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        token = get_confirm_token(response)

        if token:
            # The first streamed response is never read; release its connection
            response.close()
            params = {'id': id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)

        # Fail loudly instead of saving an HTTP error body as the download
        response.raise_for_status()
        save_response_content(response, destination)
    print("Done.")
def download_celeb_a():
    """
    Download the Celeb-A image archive into ./data and unpack it to ./data/celebA.

    Nothing is done when ./data/celebA already exists. An archive already
    present at ./data/img_align_celeba.zip is reused instead of downloaded
    again. Unless DEBUG is set, the archive is extracted, deleted, and its
    first entry is renamed to ./data/celebA.
    """
    dirpath = './data'
    data_dir = 'celebA'
    if os.path.exists(os.path.join(dirpath, data_dir)):
        print('Found Celeb-A - skip')
        return

    # The archive is written straight into this folder, so it must exist
    # before the download starts.
    os.makedirs(dirpath, exist_ok=True)

    filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
    save_path = os.path.join(dirpath, filename)

    if os.path.exists(save_path):
        print('[*] {} already exists'.format(save_path))
    else:
        download_file_from_google_drive(drive_id, save_path)

    if not DEBUG:
        with zipfile.ZipFile(save_path) as zf:
            # The first archive entry is taken as the extracted top-level folder
            zip_dir = zf.namelist()[0]
            zf.extractall(dirpath)
        # Delete the archive only after it has been closed
        os.remove(save_path)
        os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))
def images_square_grid(images, mode='RGB'):
    """
    Helper function to save images as a square grid (visualization)

    The batch is rescaled as a whole to the 0-255 range, and the first
    floor(sqrt(N)) ** 2 images are pasted into one image. Consecutive images
    fill the grid top to bottom, then left to right.

    :param images: 4-D numpy array indexed as (image, row, column, channel)
    :param mode: PIL image mode used for the tiles and the grid
    :return: PIL Image containing the grid
    """
    # Get maximum size for square grid of images
    save_size = math.floor(np.sqrt(images.shape[0]))
    height, width, channels = images.shape[1], images.shape[2], images.shape[3]

    # Integer batches would wrap around when multiplied by 255 in their own
    # dtype, so they are promoted to float first.
    if not np.issubdtype(images.dtype, np.floating):
        images = images.astype(np.float64)

    # Scale to 0-255; a constant batch has no range to stretch and maps to 0
    lowest = images.min()
    value_range = images.max() - lowest
    if value_range == 0:
        images = np.zeros(images.shape, dtype=np.uint8)
    else:
        images = (((images - lowest) * 255) / value_range).astype(np.uint8)

    # Put images in a square arrangement
    images_in_square = np.reshape(
        images[:save_size * save_size],
        (save_size, save_size, height, width, channels))
    if channels == 1:
        # Single-band modes such as 'L' need 2-D tiles, so drop the channel axis
        images_in_square = images_in_square.reshape(
            (save_size, save_size, height, width))

    # Combine images to grid image; PIL sizes and offsets are (x, y), i.e.
    # width before height.
    new_im = Image.new(mode, (width * save_size, height * save_size))
    for col_i, col_images in enumerate(images_in_square):
        for image_i, image in enumerate(col_images):
            im = Image.fromarray(image, mode)
            new_im.paste(im, (col_i * width, image_i * height))

    return new_im