-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluation.py
131 lines (108 loc) · 4.24 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Utility functions for computing FID/Inception scores."""
import jax
import numpy as np
import six
import tensorflow as tf
import tensorflow_gan as tfgan
import tensorflow_hub as tfhub
# TF-Hub handle of the Inception module loaded by default in
# `get_inception_model`.
INCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1'
# Output keys of that module that are used here: class logits and the final
# pooling-layer features.
INCEPTION_OUTPUT = 'logits'
INCEPTION_FINAL_POOL = 'pool_3'
# Per-field output dtypes, passed as `dtypes` to
# `tfgan.eval.run_classifier_fn`.
_DEFAULT_DTYPES = {
    field: tf.float32 for field in (INCEPTION_OUTPUT, INCEPTION_FINAL_POOL)
}
# Presumably the expected input image resolution of the Inception network;
# not referenced elsewhere in this chunk.
INCEPTION_DEFAULT_IMAGE_SIZE = 299
def get_inception_model(inceptionv3=False):
  """Load an Inception network from TF-Hub.

  Args:
    inceptionv3: If `True`, load the ImageNet InceptionV3 feature-vector
      module; otherwise load the TF-GAN evaluation module `INCEPTION_TFHUB`.

  Returns:
    Whatever `tfhub.load` returns for the selected handle.
  """
  handle = (
      'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4'
      if inceptionv3 else INCEPTION_TFHUB)
  return tfhub.load(handle)
def load_dataset_stats(config):
  """Load the pre-computed Inception statistics for the configured dataset.

  Args:
    config: Configuration object. `config.data.dataset` selects the stats
      file; for 'LSUN', `config.data.category` and `config.data.image_size`
      are also read to build the file name.

  Returns:
    The result of `np.load` on the `.npz` stats file.

  Raises:
    ValueError: If no stats file is known for `config.data.dataset`.
  """
  dataset = config.data.dataset
  # Datasets whose stats file does not depend on any other config field.
  fixed_paths = {
      'CIFAR10': 'assets/stats/cifar10_stats.npz',
      'CELEBA': 'assets/stats/celeba_stats.npz',
  }
  if dataset in fixed_paths:
    filename = fixed_paths[dataset]
  elif dataset == 'LSUN':
    filename = (f'assets/stats/lsun_{config.data.category}_'
                f'{config.data.image_size}_stats.npz')
  else:
    raise ValueError(f'Dataset {dataset} stats not found.')
  # NOTE(review): for .npz files np.load returns a lazily-read NpzFile, and
  # the GFile below is closed before the caller indexes it. This appears to
  # rely on GFile reopening on read after close; confirm before changing the
  # file-handling here.
  with tf.io.gfile.GFile(filename, 'rb') as stats_file:
    return np.load(stats_file)
def classifier_fn_from_tfhub(output_fields, inception_model,
                             return_tensor=False):
  """Build a classifier function around an already-loaded TF-Hub model.

  Adapted from the TF-GAN helper. Unlike that helper, the model is passed in
  once, so it is not reloaded every time the returned function is called.

  Args:
    output_fields: A string, a list of strings, or `None`. When not `None`,
      the model output is assumed to be a dictionary and only these keys are
      kept.
    inception_model: A model loaded from TF-Hub.
    return_tensor: If `True`, return the single selected output directly
      instead of a dictionary.

  Returns:
    A one-argument function that maps an image tensor to the model outputs.
  """
  if isinstance(output_fields, six.string_types):
    selected = [output_fields]
  else:
    selected = output_fields

  def _classifier_fn(images):
    outputs = inception_model(images)
    if selected is not None:
      outputs = {field: outputs[field] for field in selected}
    if return_tensor:
      # Unwrapping to a single tensor only makes sense for one field.
      assert len(outputs) == 1
      outputs = list(outputs.values())[0]
    # Flatten each output (keeping the leading batch dimension).
    return tf.nest.map_structure(tf.compat.v1.layers.flatten, outputs)

  return _classifier_fn
@tf.function
def run_inception_jit(inputs,
                      inception_model,
                      num_batches=1,
                      inceptionv3=False):
  """Run the Inception network on images assumed to be within [0, 255].

  Args:
    inputs: Image batch with pixel values assumed to be in [0, 255].
    inception_model: Model obtained from `get_inception_model`.
    num_batches: Number of chunks `inputs` is split into by
      `tfgan.eval.run_classifier_fn`.
    inceptionv3: If `True`, `inception_model` is the InceptionV3
      feature-vector module; otherwise it is the TF-GAN Inception module.

  Returns:
    For the TF-GAN module, a dictionary with `logits` and `pool_3` outputs.
    For InceptionV3, the model's (flattened) output tensor.
  """
  if inceptionv3:
    # InceptionV3 feature-vector module: rescale pixels to [0, 1].
    inputs = tf.cast(inputs, tf.float32) / 255.
    # This model yields a single tensor rather than a dict of fields, so the
    # dtype structure handed to run_classifier_fn must be a single dtype.
    # Passing the per-field dict breaks the num_batches > 1 (map_fn) path.
    dtypes = tf.float32
  else:
    # TF-GAN Inception module: rescale pixels to [-1, 1].
    inputs = (tf.cast(inputs, tf.float32) - 127.5) / 127.5
    dtypes = _DEFAULT_DTYPES
  return tfgan.eval.run_classifier_fn(
      inputs,
      num_batches=num_batches,
      classifier_fn=classifier_fn_from_tfhub(None, inception_model),
      dtypes=dtypes)
@tf.function
def run_inception_distributed(input_tensor,
                              inception_model,
                              num_batches=1,
                              inceptionv3=False):
  """Shard Inception inference across the local accelerator devices.

  The batch is split evenly along axis 0 into one shard per local JAX device.
  Each shard runs on '/TPU:i' when the first JAX device string contains
  'TPU', and on '/GPU:i' otherwise. The per-shard results are concatenated
  on the CPU.

  Args:
    input_tensor: The input images, assumed to be within [0, 255]. `tf.split`
      requires its leading dimension to be divisible by the number of local
      devices.
    inception_model: The Inception network obtained from `tfhub`.
    num_batches: Number of batches each shard is divided into.
    inceptionv3: If `True`, use InceptionV3; otherwise use the TF-GAN
      Inception module.

  Returns:
    A dictionary with keys `pool_3` and `logits`. For InceptionV3, `pool_3`
    holds the model's output features and `logits` is `None`.
  """
  shards = tf.split(input_tensor, jax.local_device_count(), axis=0)
  pool3_shards = []
  logits_shards = None if inceptionv3 else []
  if 'TPU' in str(jax.devices()[0]):
    device_template = '/TPU:{}'
  else:
    device_template = '/GPU:{}'

  for index, shard in enumerate(shards):
    with tf.device(device_template.format(index)):
      features = run_inception_jit(
          tf.identity(shard),
          inception_model,
          num_batches=num_batches,
          inceptionv3=inceptionv3)
      if inceptionv3:
        # InceptionV3 returns a single feature tensor, not a dict.
        pool3_shards.append(features)
      else:
        pool3_shards.append(features['pool_3'])
        logits_shards.append(features['logits'])  # pytype: disable=attribute-error

  with tf.device('/CPU'):
    merged_pool3 = tf.concat(pool3_shards, axis=0)
    merged_logits = None if inceptionv3 else tf.concat(logits_shards, axis=0)
    return {'pool_3': merged_pool3, 'logits': merged_logits}