-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelpers.py
28 lines (20 loc) · 940 Bytes
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import tensorflow as tf
# G = A^T * A
def gram_matrix(input_tensor):
    """Return the per-image Gram matrix of a batch of feature maps.

    Each image's spatial positions are flattened so that ``A`` has shape
    ``(num_pixels, channels)``. The Gram matrix ``G = A^T A`` is computed
    per batch element and divided by the number of spatial locations.

    Args:
        input_tensor: 4-D tensor laid out as
            ``(batch_size, height, width, channels)``. A floating-point dtype
            is expected (NOTE(review): confirm callers never pass integers).

    Returns:
        Tensor of shape ``(batch_size, channels, channels)``.
    """
    # Index the dynamic shape rather than tuple-unpacking it. Unpacking
    # iterates over the shape tensor, which raises
    # OperatorNotAllowedInGraphError when this runs inside tf.function.
    shape = tf.shape(input_tensor)
    batch_size, height, width, channels = shape[0], shape[1], shape[2], shape[3]
    num_pixels = height * width

    # A: (batch_size, num_pixels, channels)
    features = tf.reshape(input_tensor, [batch_size, num_pixels, channels])

    # G = A^T A, batched over axis 0 -> (batch_size, channels, channels)
    gram = tf.linalg.matmul(features, features, transpose_a=True)

    # Normalize by the number of locations. Cast to the result's dtype, not
    # a hard-coded float32, so float16/float64 inputs divide without a
    # dtype mismatch.
    return gram / tf.cast(num_pixels, gram.dtype)
# The image is float-valued, so keep its pixel values inside [0, 1].
def clip_0_1(image):
    """Clamp every element of ``image`` into the closed range [0.0, 1.0].

    Args:
        image: tensor of pixel values (presumably a float image normalized
            to [0, 1]; NOTE(review): confirm with callers).

    Returns:
        Tensor with the same shape as ``image``, where values below 0.0 are
        replaced by 0.0 and values above 1.0 are replaced by 1.0.
    """
    lower_bound = 0.0
    upper_bound = 1.0
    clipped = tf.clip_by_value(
        image,
        clip_value_min=lower_bound,
        clip_value_max=upper_bound,
    )
    return clipped