-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathplot_times.py
32 lines (28 loc) · 932 Bytes
/
plot_times.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import matplotlib.pyplot as plt
# Read the benchmark timings. Each non-blank line of results/timings.txt is
# one comma-separated record of (at least) six numbers:
#   col 0 -> times       (x value; plotted as "Number of Generated Tokens")
#   col 1 -> gpt_times   (GPT2 unconditional)
#   col 2 -> lda_times   (TLG + LDA)
#   col 3 -> lsi_times   (TLG + LSI)
#   col 4 -> ctrl_times  (CTRL)
#   col 5 -> pplm_times  (PPLM)
# Columns 1-5 are plotted as "Time (seconds)". Any extra columns are ignored.
_TIMINGS_PATH = "results/timings.txt"
_N_COLUMNS = 6

times = []
gpt_times = []
lda_times = []
lsi_times = []
ctrl_times = []
pplm_times = []
with open(_TIMINGS_PATH, encoding="utf-8") as fr:
    for line_no, line in enumerate(fr, start=1):
        if not line.strip():
            # Skip blank lines (e.g. a trailing empty line); otherwise
            # float("") would abort the whole script.
            continue
        items = [float(item) for item in line.split(',')]
        if len(items) < _N_COLUMNS:
            raise ValueError(
                f"{_TIMINGS_PATH} line {line_no}: expected {_N_COLUMNS} "
                f"comma-separated values, got {len(items)}"
            )
        times.append(items[0])
        gpt_times.append(items[1])
        lda_times.append(items[2])
        lsi_times.append(items[3])
        ctrl_times.append(items[4])
        pplm_times.append(items[5])
# Plot generation time against number of generated tokens for each method,
# then save the figure to times.png and display it.
#
# Matplotlib 3.6 renamed the bundled "seaborn" style to "seaborn-v0_8", and
# 3.8 removed the old name (so 'seaborn' raises OSError there). Use whichever
# name this installation provides, falling back to Matplotlib's defaults.
_plot_style = next(
    (name for name in ('seaborn-v0_8', 'seaborn') if name in plt.style.available),
    'default',
)
with plt.style.context(_plot_style):
    for series, label in (
        (gpt_times, 'GPT2 (Unconditional)'),
        (lda_times, 'TLG + LDA'),
        (lsi_times, 'TLG + LSI'),
        (ctrl_times, 'CTRL'),
        (pplm_times, 'PPLM'),
    ):
        plt.plot(times, series, label=label)
    plt.xlabel("Number of Generated Tokens")
    plt.ylabel("Time (seconds)")
    plt.legend(loc="upper left")
    # Save inside the style context so savefig.* rcParams from the style apply.
    plt.savefig("times.png")
    plt.show()