-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotting.py
56 lines (44 loc) · 1.45 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import os
def plot_split_mnist_scores(scores, n_tasks=5):
    """
    Plot one accuracy curve per task, each in its own subplot of a 1 x n_tasks grid.

    Draws into the current matplotlib figure; no new figure is created.

    :param scores: 2-D array-like indexed as ``scores[row, task]``; column ``i``
        is plotted in subplot ``i + 1``.
        NOTE(review): rows presumably index the training stage -- confirm with callers.
    :param n_tasks: number of columns of ``scores`` to plot, and number of
        subplots (defaults to 5, the value previously hard-coded).
    :return: None
    """
    for task_idx in range(n_tasks):
        plt.subplot(1, n_tasks, task_idx + 1)
        plt.plot(scores[:, task_idx])
        plt.ylabel('Classification accuracy')
        plt.xlabel('Task number')
        # Fixed y-range so the subplots are visually comparable.
        plt.ylim([0.4, 1.1])
def plot_average_split_mnist_scores(scores, filename=None, figure_folder='./figs'):
    """
    Draw one seaborn point plot per task of accuracy against training stage.

    Each subplot ``task`` shows ``scores[:, task, :]``: the x axis is the
    (1-based) index along axis 0 and the y values are pooled over axis 2.

    :param scores: 3-D array-like indexed as ``scores[stage, task, run]``.
        NOTE(review): the loop runs over ``scores.shape[0]`` but slices axis 1,
        so this assumes ``scores.shape[0] == scores.shape[1]`` -- confirm.
    :param filename: if given, the current figure is saved as
        ``os.path.join(figure_folder, filename)``.
    :param figure_folder: directory the figure is saved into; created if it
        does not exist (default ``'./figs'``, the value previously hard-coded).
    :return: None
    """
    n_tasks = scores.shape[0]
    n_runs = scores.shape[2]
    # Stage labels 1..n_tasks, each repeated n_runs times. This matches the
    # row-major order of scores[:, task, :].flatten(). It replaces
    # np.matlib.repmat, which required an explicit `import numpy.matlib`.
    stage_labels = np.repeat(np.arange(1, n_tasks + 1), n_runs)
    for task in range(n_tasks):
        plt.subplot(1, n_tasks, task + 1)
        # pd.DataFrame raises on length mismatch instead of silently
        # NaN-aligning the way pd.concat of two Series would.
        df = pd.DataFrame({
            'Task': stage_labels,
            'Score': scores[:, task, :].flatten(),
        })
        sns.pointplot(x=df['Task'], y=df['Score'])
        plt.title('Task ' + str(task + 1))
        plt.ylabel('Accuracy')
        plt.ylim([0.4, 1.05])
    if filename is not None:
        # savefig does not create missing directories.
        os.makedirs(figure_folder, exist_ok=True)
        plt.savefig(os.path.join(figure_folder, filename))
def show_grid_of_digits(data, nr, nc, image_shape=(28, 28)):
    """
    Display the first ``nr * nc`` images of ``data[0]`` in an ``nr`` x ``nc`` grid.

    Draws into the current matplotlib figure, with axes hidden on every subplot.

    :param data: indexable whose first element holds the flattened images.
        NOTE(review): presumably an ``(images, labels)`` pair -- confirm with callers.
    :param nr: number of rows
    :param nc: number of columns
    :param image_shape: shape each flattened image is reshaped to before
        display (default ``(28, 28)``, the value previously hard-coded).
    :return: None
    """
    images = data[0]
    for i in range(nr * nc):
        # plt.subplot returns the new current Axes, so no gca() lookup is needed.
        ax = plt.subplot(nr, nc, i + 1)
        plt.imshow(images[i].reshape(image_shape), cmap='gray')
        # Hide both axes; only the pixels are of interest.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)