-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathvisualize.py
152 lines (114 loc) · 5.04 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# -*- coding: utf-8 -*-
import numpy as np
from matplotlib import pyplot as plt
import scipy
import scipy.stats
from imageio import imsave
import cv2
def concat_images(images, image_width, spacer_size):
    """Concatenate images horizontally with a white spacer between neighbours.

    Args:
        images: Sequence of image arrays that share their height and any
            trailing (channel) dimensions.
        image_width: Nominal side length of the images. Kept for interface
            compatibility; the spacer is sized from the images themselves so
            that non-square or non-RGBA images also concatenate cleanly.
        spacer_size: Width in pixels of the spacer placed between images.

    Returns:
        The result of ``np.hstack`` over the images interleaved with spacers.

    Raises:
        ValueError: If ``images`` is empty (raised by ``np.hstack``).
    """
    pieces = []
    if len(images) > 1:
        # The spacer must match the images in every dimension except width,
        # otherwise np.hstack rejects the mix.
        reference = np.asarray(images[0])
        spacer_shape = (reference.shape[0], spacer_size) + reference.shape[2:]
        spacer = np.ones(spacer_shape, dtype=np.uint8) * 255
        for image in images:
            if pieces:
                pieces.append(spacer)
            pieces.append(image)
    else:
        # Zero or one image: no spacer is needed.
        pieces.extend(images)
    return np.hstack(pieces)
def concat_images_in_rows(images, row_size, image_width, spacer_size=4):
    """Tile images into ``row_size`` rows separated by white spacer bands.

    Each row holds ``len(images) // row_size`` consecutive images; any
    remainder at the end of ``images`` is not drawn.

    Args:
        images: Sequence of image arrays.
        row_size: Number of rows in the output grid.
        image_width: Width of one image, forwarded to ``concat_images``.
        spacer_size: Thickness in pixels of the spacers between images and
            between rows.

    Returns:
        The result of ``np.vstack`` over the rows interleaved with spacers.

    Raises:
        ZeroDivisionError: If ``row_size`` is 0.
    """
    column_size = len(images) // row_size
    rows = [
        concat_images(images[column_size * row:column_size * (row + 1)],
                      image_width, spacer_size)
        for row in range(row_size)
    ]

    stacked = []
    if len(rows) > 1:
        # Size the band from an actual row rather than from a width formula
        # and a fixed channel count, so it always matches what gets stacked.
        band_shape = (spacer_size,) + rows[0].shape[1:]
        spacer_h = np.ones(band_shape, dtype=np.uint8) * 255
        for row_image in rows:
            if stacked:
                stacked.append(spacer_h)
            stacked.append(row_image)
    else:
        stacked.extend(rows)
    return np.vstack(stacked)
def convert_to_colormap(im, cmap):
    """Pass ``im`` through the callable ``cmap`` and return it as 8-bit values.

    The colormap output is multiplied by 255 and cast to ``np.uint8``.
    """
    colored = cmap(im)
    return np.uint8(colored * 255)
def rgb(im, cmap='jet', smooth=True):
    """Render an array as an 8-bit colour image.

    The array is min-max normalised, optionally blurred with a 3x3 Gaussian
    kernel, mapped through the named Matplotlib colormap, multiplied by 255
    and cast to ``np.uint8``.

    Args:
        im: Array to render (for example a single rate map).
        cmap: Name of the Matplotlib colormap to apply.
        smooth: Whether to apply the Gaussian blur before colour mapping.

    Returns:
        The colour-mapped array as ``np.uint8``.
    """
    # pyplot.get_cmap is the supported lookup; matplotlib.cm.get_cmap was
    # deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap(cmap)
    # A constant image yields 0/0 here. Silence that warning locally instead
    # of changing NumPy's error handling for the rest of the process.
    with np.errstate(invalid='ignore'):
        lowest = np.min(im)
        im = (im - lowest) / (np.max(im) - lowest)
    if smooth:
        im = cv2.GaussianBlur(im, (3, 3), sigmaX=1, sigmaY=0)
    im = cmap(im)
    im = np.uint8(im * 255)
    return im
def plot_ratemaps(activations, n_plots, cmap='jet', smooth=True, width=16):
    """Render the first ``n_plots`` rate maps and tile them into one image.

    Args:
        activations: Array of rate maps indexed along its first axis; its
            last axis length is used as the width of one tile.
        n_plots: Number of maps to draw.
        cmap: Matplotlib colormap name passed to ``rgb``.
        smooth: Whether ``rgb`` blurs each map before colouring.
        width: Divisor used to derive the number of rows
            (``n_plots // width``).

    Returns:
        The tiled image produced by ``concat_images_in_rows``.
    """
    images = [rgb(im, cmap, smooth) for im in activations[:n_plots]]
    # With fewer than `width` maps the integer division gives zero rows,
    # which makes concat_images_in_rows divide by zero. Use a single row.
    n_rows = max(1, n_plots // width)
    rm_fig = concat_images_in_rows(images, n_rows, activations.shape[-1])
    return rm_fig
def compute_ratemaps(model, trajectory_generator, options, res=20, n_avg=None, Ng=512, idxs=None):
    """Compute spatial firing fields by binning unit activity over position.

    Draws ``n_avg`` test batches, evaluates ``model.g`` on each, and averages
    the selected units' activity inside a ``res`` x ``res`` grid of position
    bins. Positions are mapped to bins assuming the box is centred on the
    origin (``-box/2 .. box/2`` maps to ``0 .. res``); samples falling outside
    the grid are ignored.

    Args:
        model: Object whose ``g(inputs)`` returns a tensor indexed as
            ``[batch, time, unit]``.
        trajectory_generator: Object whose ``get_test_batch()`` returns
            ``(inputs, positions, _)``.
        options: Object providing ``sequence_length``, ``batch_size``,
            ``box_width`` and ``box_height``.
        res: Number of bins along each spatial axis.
        n_avg: Number of batches to draw; defaults to
            ``1000 // options.sequence_length``.
        Ng: Number of units to keep.
        idxs: Optional indices of the units to keep. ``None`` or an empty
            sequence selects the first ``Ng`` units.

    Returns:
        Tuple ``(activations, rate_map, g, pos)``: the binned means with shape
        ``[Ng, res, res]``, the same data flattened to ``[Ng, res * res]``,
        all collected activity with shape ``[samples, Ng]``, and all collected
        positions with shape ``[samples, 2]``.
    """
    if not n_avg:
        n_avg = 1000 // options.sequence_length

    # Only fall back to the default when no selection was supplied. Testing
    # np.any(idxs) would also throw away a valid selection whose indices are
    # all zero, such as [0].
    if idxs is None or len(idxs) == 0:
        idxs = np.arange(Ng)
    idxs = idxs[:Ng]

    n_samples = options.batch_size * options.sequence_length
    g = np.zeros([n_avg, n_samples, Ng])
    pos = np.zeros([n_avg, n_samples, 2])

    activations = np.zeros([Ng, res, res])
    counts = np.zeros([res, res])

    for index in range(n_avg):
        inputs, pos_batch, _ = trajectory_generator.get_test_batch()
        g_batch = model.g(inputs).detach().cpu().numpy()

        pos_batch = np.asarray(pos_batch.cpu()).reshape(-1, 2)
        g_batch = g_batch[:, :, idxs].reshape(-1, Ng)

        g[index] = g_batch
        pos[index] = pos_batch

        # Continuous bin coordinates of every sample.
        x_batch = (pos_batch[:, 0] + options.box_width / 2) / options.box_width * res
        y_batch = (pos_batch[:, 1] + options.box_height / 2) / options.box_height * res

        inside = (x_batch >= 0) & (x_batch < res) & (y_batch >= 0) & (y_batch < res)
        x_bin = x_batch[inside].astype(int)
        y_bin = y_batch[inside].astype(int)

        # np.add.at accumulates repeated indices one by one, which a plain
        # fancy-indexed `+=` would not do when several samples share a bin.
        np.add.at(counts, (x_bin, y_bin), 1)
        np.add.at(activations, (slice(None), x_bin, y_bin), g_batch[inside].T)

    # Turn sums into means; bins that were never visited stay at zero.
    visited = counts > 0
    activations[:, visited] /= counts[visited]

    g = g.reshape([-1, Ng])
    pos = pos.reshape([-1, 2])

    # (scipy.stats.binned_statistic_2d was tried for this binning and found
    # to be slightly slower.)
    rate_map = activations.reshape(Ng, -1)

    return activations, rate_map, g, pos
def save_ratemaps(model, trajectory_generator, options, step, res=20, n_avg=None):
    """Compute rate maps for every unit and write them to a PNG file.

    The image is saved as ``<options.save_dir>/<options.run_ID>/<step>.png``.
    When ``n_avg`` is not given it defaults to
    ``1000 // options.sequence_length``.
    """
    n_avg = n_avg or 1000 // options.sequence_length

    # Only the binned activations are needed for plotting.
    activations, _, _, _ = compute_ratemaps(
        model, trajectory_generator, options, res=res, n_avg=n_avg)
    figure = plot_ratemaps(activations, n_plots=len(activations))

    out_dir = options.save_dir + "/" + options.run_ID
    out_path = out_dir + "/" + str(step) + ".png"
    imsave(out_path, figure)
def save_autocorr(sess, model, save_name, trajectory_generator, step, flags):
    """Collect positions and activations and pass them to the scoring plot.

    Builds a ``scores.GridScorer`` with 20 bins, a fixed coordinate range and
    ten mask ranges, runs the session on 100 freshly generated feed dicts
    fetching ``model.target_pos`` and ``model.g``, merges the results with
    ``utils.concat_dict``, and calls ``utils.get_scores_and_plot`` with the
    output file ``<save_name>/autocorrs_<step>.pdf`` under
    ``<flags.save_dir>/``.

    NOTE(review): ``scores`` and ``utils`` are not imported in the visible
    part of this file, and the ``sess.run`` / ``feed_dict`` usage looks like
    a TensorFlow 1.x session API. Confirm this function is still in use.
    """
    starts = [0.2] * 10
    ends = np.linspace(0.4, 1.0, num=10)
    coord_range = ((-1.1, 1.1), (-1.1, 1.1))
    # zip() is a one-shot iterator in Python 3; materialise the pairs so the
    # scorer can walk them more than once if it needs to.
    masks_parameters = list(zip(starts, ends.tolist()))
    latest_epoch_scorer = scores.GridScorer(20, coord_range, masks_parameters)

    res = dict()
    index_size = 100
    for _ in range(index_size):
        feed_dict = trajectory_generator.feed_dict(flags.box_width, flags.box_height)
        mb_res = sess.run({
            'pos_xy': model.target_pos,
            'bottleneck': model.g,
        }, feed_dict=feed_dict)
        res = utils.concat_dict(res, mb_res)

    filename = save_name + '/autocorrs_' + str(step) + '.pdf'
    imdir = flags.save_dir + '/'
    utils.get_scores_and_plot(
        latest_epoch_scorer, res['pos_xy'], res['bottleneck'],
        imdir, filename)