-
Notifications
You must be signed in to change notification settings - Fork 268
/
Copy pathplot.py
361 lines (317 loc) · 11.3 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
import pathlib
from typing import Any, Callable, List, Optional, Union
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import tqdm
from numpy import arange, median, nan_to_num
from .load_data_ import axl_filename
from .result_set import ResultSet
# Type aliases used in the plotting method signatures of this module.
# Tick/axis labels: one string per plotted entry.
namesType = List[str]
# Plot titles are plain strings.
titleType = str
# Nested numeric data: a list of rows, each row a list of ints/floats
# (violin samples or heatmap matrix rows, depending on the caller).
dataType = List[List[Union[int, float]]]
class Plot(object):
    """Build matplotlib figures summarising a tournament ``ResultSet``.

    Every public plotting method returns a ``matplotlib.figure.Figure``.
    Methods that take an ``ax`` argument draw into that axes; when ``ax`` is
    ``None`` a new figure and axes are created with ``plt.subplots()``.
    """

    def __init__(self, result_set: ResultSet) -> None:
        """Store the result set and cache its player count and players.

        Parameters
        ----------
        result_set : ResultSet
            The results to plot.
        """
        self.result_set = result_set
        self.num_players = self.result_set.num_players
        self.players = self.result_set.players

    @staticmethod
    def _axes_and_figure(
        ax: Optional[matplotlib.axes.Axes],
        get_figure: Callable[
            [matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
        ],
    ):
        """Return ``(ax, figure)``, creating a new figure/axes if ``ax`` is None.

        Raises
        ------
        RuntimeError
            If ``get_figure(ax)`` does not return a ``matplotlib.figure.Figure``.
        """
        if ax is None:
            _, ax = plt.subplots()
        figure = get_figure(ax)
        if not isinstance(figure, matplotlib.figure.Figure):
            raise RuntimeError(
                "get_figure unexpectedly returned a non-figure object"
            )
        return ax, figure

    def _violinplot(
        self,
        data: dataType,
        names: namesType,
        title: Optional[titleType] = None,
        ax: Optional[matplotlib.axes.Axes] = None,
        get_figure: Callable[
            [matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
        ] = lambda ax: ax.get_figure(),
    ) -> matplotlib.figure.Figure:
        """For making violinplots.

        Parameters
        ----------
        data : list of lists of numbers
            One sample per violin. Violin positions are derived from
            ``self.num_players``, so ``len(data)`` is expected to match it.
        names : list of str
            Tick labels, one per violin, in the same order as ``data``.
        title : str, optional
            Axes title; skipped when falsy.
        ax : matplotlib.axes.Axes, optional
            Axes to draw into; a new figure/axes is created when ``None``.
        get_figure : callable
            Maps ``ax`` to its figure; must return a ``Figure``.

        Returns
        -------
        matplotlib.figure.Figure
            The figure containing ``ax`` (resized by this call).

        Raises
        ------
        RuntimeError
            If ``get_figure`` returns something other than a ``Figure``.
        """
        ax, figure = self._axes_and_figure(ax, get_figure)
        # Width grows with the number of players but never drops below 12.
        width = max(self.num_players / 3, 12)
        height = width / 2
        spacing = 4
        positions = spacing * arange(1, self.num_players + 1, 1)
        figure.set_size_inches(width, height)
        ax.violinplot(
            data,
            positions=positions,
            widths=spacing / 2,
            showmedians=True,
            showextrema=False,
        )
        ax.set_xticks(positions)
        ax.set_xticklabels(names, rotation=90)
        ax.set_xlim((0, spacing * (self.num_players + 1)))
        ax.tick_params(axis="both", which="both", labelsize=8)
        if title:
            ax.set_title(title)
        # Lay out the figure we drew on; plt.tight_layout() would act on
        # pyplot's current figure, which differs when a foreign ax is passed.
        figure.tight_layout()
        return figure

    # Box and Violin plots for mean score, score differences, wins, and match
    # lengths

    @property
    def _boxplot_dataset(self):
        """Normalised scores per player, in ranking order, NaNs replaced."""
        return [
            list(nan_to_num(self.result_set.normalised_scores[ir]))
            for ir in self.result_set.ranking
        ]

    @property
    def _boxplot_xticks_locations(self):
        """Integers 1 .. len(ranked_names) + 1 inclusive.

        NOTE(review): this yields one more location than there are names and
        is not referenced elsewhere in this class; kept as-is for callers.
        """
        return list(range(1, len(self.result_set.ranked_names) + 2))

    @property
    def _boxplot_xticks_labels(self):
        """Ranked player names converted to ``str``."""
        return [str(n) for n in self.result_set.ranked_names]

    def boxplot(
        self,
        title: Optional[titleType] = None,
        ax: Optional[matplotlib.axes.Axes] = None,
    ) -> matplotlib.figure.Figure:
        """For the specific mean score boxplot (drawn as violins)."""
        data = self._boxplot_dataset
        names = self._boxplot_xticks_labels
        figure = self._violinplot(data, names, title=title, ax=ax)
        return figure

    @property
    def _winplot_dataset(self):
        """Return ``(wins, names)`` with players sorted by median wins.

        Sorting is descending on ``(median, player_index)`` tuples, so ties
        in the median are broken by the higher player index first.
        """
        # Sort wins by median
        wins = self.result_set.wins
        medians = map(median, wins)
        medians = sorted(
            [(m, i) for (i, m) in enumerate(medians)], reverse=True
        )
        # Reorder and grab names
        wins = [wins[x[-1]] for x in medians]
        ranked_names = [str(self.players[x[-1]]) for x in medians]
        return wins, ranked_names

    def winplot(
        self,
        title: Optional[titleType] = None,
        ax: Optional[matplotlib.axes.Axes] = None,
    ) -> matplotlib.figure.Figure:
        """Plots the distributions for the number of wins for each strategy."""
        data, names = self._winplot_dataset
        # Resolve the axes here so the y-limits below go on the axes we
        # actually drew into (plt.ylim would target pyplot's current axes).
        if ax is None:
            _, ax = plt.subplots()
        figure = self._violinplot(data, names, title=title, ax=ax)
        # Expand ylim a bit
        maximum = max(max(w) for w in data)
        ax.set_ylim(-0.5, 0.5 + maximum)
        return figure

    @property
    def _sd_ordering(self):
        """Player ordering shared by the score-difference plots."""
        return self.result_set.ranking

    @property
    def _sdv_plot_dataset(self):
        """Return ``(diffs, names)``: each player's score differences,
        flattened across opponents, in ``_sd_ordering`` order."""
        ordering = self._sd_ordering
        # score_diffs is nested player -> opponent -> iterable of diffs.
        diffs = [
            [score_diff for opponent in player for score_diff in opponent]
            for player in self.result_set.score_diffs
        ]
        # Reorder and grab names
        diffs = [diffs[i] for i in ordering]
        ranked_names = [str(self.players[i]) for i in ordering]
        return diffs, ranked_names

    def sdvplot(
        self,
        title: Optional[titleType] = None,
        ax: Optional[matplotlib.axes.Axes] = None,
    ) -> matplotlib.figure.Figure:
        """Score difference violin plots to visualize the distributions of how
        players attain their payoffs."""
        diffs, ranked_names = self._sdv_plot_dataset
        figure = self._violinplot(diffs, ranked_names, title=title, ax=ax)
        return figure

    @property
    def _lengthplot_dataset(self):
        """Match lengths per player, in ranking order, flattened over the
        outer level of ``match_lengths`` (presumably repetitions — confirm)."""
        match_lengths = self.result_set.match_lengths
        return [
            [length for rep in match_lengths for length in rep[playeri]]
            for playeri in self.result_set.ranking
        ]

    def lengthplot(
        self,
        title: Optional[titleType] = None,
        ax: Optional[matplotlib.axes.Axes] = None,
    ) -> matplotlib.figure.Figure:
        """For the specific match length boxplot (drawn as violins)."""
        data = self._lengthplot_dataset
        names = self._boxplot_xticks_labels
        figure = self._violinplot(data, names, title=title, ax=ax)
        return figure

    @property
    def _payoff_dataset(self):
        """Payoff matrix with rows and columns reordered by ranking."""
        pm = self.result_set.payoff_matrix
        return [
            [pm[r1][r2] for r2 in self.result_set.ranking]
            for r1 in self.result_set.ranking
        ]

    @property
    def _pdplot_dataset(self):
        """Return ``(matrix, names)``: mean payoff differences reordered by
        ``_sd_ordering`` on both axes."""
        # Order like the sdv_plot
        ordering = self._sd_ordering
        pdm = self.result_set.payoff_diffs_means
        # Reorder and grab names
        matrix = [[pdm[r1][r2] for r2 in ordering] for r1 in ordering]
        players = self.result_set.players
        ranked_names = [str(players[i]) for i in ordering]
        return matrix, ranked_names

    def _payoff_heatmap(
        self,
        data: dataType,
        names: namesType,
        title: Optional[titleType] = None,
        ax: Optional[matplotlib.axes.Axes] = None,
        cmap: str = "viridis",
        get_figure: Callable[
            [matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
        ] = lambda ax: ax.get_figure(),
    ) -> matplotlib.figure.Figure:
        """Generic heatmap plot.

        Parameters
        ----------
        data : list of lists of numbers
            Square matrix; ticks are placed for ``result_set.num_players``
            rows/columns.
        names : list of str
            Labels for both axes, in matrix order.
        title : str, optional
            Used as the x-axis label (not the axes title); skipped when falsy.
        ax : matplotlib.axes.Axes, optional
            Axes to draw into; a new figure/axes is created when ``None``.
        cmap : str
            Matplotlib colormap name.
        get_figure : callable
            Maps ``ax`` to its figure; must return a ``Figure``.

        Raises
        ------
        RuntimeError
            If ``get_figure`` returns something other than a ``Figure``.
        """
        ax, figure = self._axes_and_figure(ax, get_figure)
        width = max(self.num_players / 4, 12)
        height = width
        figure.set_size_inches(width, height)
        mat = ax.matshow(data, cmap=cmap)
        ax.set_xticks(range(self.result_set.num_players))
        ax.set_yticks(range(self.result_set.num_players))
        ax.set_xticklabels(names, rotation=90)
        ax.set_yticklabels(names)
        ax.tick_params(axis="both", which="both", labelsize=16)
        if title:
            ax.set_xlabel(title)
        figure.colorbar(mat, ax=ax)
        # Lay out this figure rather than pyplot's current one.
        figure.tight_layout()
        return figure

    def pdplot(
        self,
        title: Optional[titleType] = None,
        ax: Optional[matplotlib.axes.Axes] = None,
    ) -> matplotlib.figure.Figure:
        """Payoff difference heatmap to visualize the distributions of how
        players attain their payoffs."""
        matrix, names = self._pdplot_dataset
        return self._payoff_heatmap(matrix, names, title=title, ax=ax)

    def payoff(
        self,
        title: Optional[titleType] = None,
        ax: Optional[matplotlib.axes.Axes] = None,
    ) -> matplotlib.figure.Figure:
        """Payoff heatmap to visualize the distributions of how
        players attain their payoffs."""
        data = self._payoff_dataset
        names = self.result_set.ranked_names
        return self._payoff_heatmap(data, names, title=title, ax=ax)

    # Ecological Plot

    def stackplot(
        self,
        eco,
        title: Optional[titleType] = None,
        logscale: bool = True,
        ax: Optional[matplotlib.axes.Axes] = None,
        get_figure: Callable[
            [matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
        ] = lambda ax: ax.get_figure(),
    ) -> matplotlib.figure.Figure:
        """Stacked area plot of population sizes over turns.

        Parameters
        ----------
        eco : object
            Must expose ``population_sizes`` indexed as ``[turn][player]``
            (presumably an ecosystem simulation result — confirm at caller).
        title : str, optional
            Axes title; set whenever it is not ``None``.
        logscale : bool
            When True, the x (turn) axis uses a log scale.
        ax : matplotlib.axes.Axes, optional
            Axes to draw into; a new figure/axes is created when ``None``.
        get_figure : callable
            Maps ``ax`` to its figure; must return a ``Figure``.

        Raises
        ------
        RuntimeError
            If ``get_figure`` returns something other than a ``Figure``.
        """
        populations = eco.population_sizes
        ax, figure = self._axes_and_figure(ax, get_figure)
        turns = range(len(populations))
        # One series per player, stacked in ranking order.
        pops = [
            [populations[iturn][ir] for iturn in turns]
            for ir in self.result_set.ranking
        ]
        ax.stackplot(turns, *pops)
        ax.yaxis.tick_left()
        ax.yaxis.set_label_position("right")
        ax.yaxis.labelpad = 25.0
        ax.set_ylim((0.0, 1.0))
        ax.set_ylabel("Relative population size")
        ax.set_xlabel("Turn")
        if title is not None:
            ax.set_title(title)
        # x in axes coordinates, y in data coordinates: labels sit just left
        # of the axes at evenly spaced heights.
        trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
        ticks = []
        for i, n in enumerate(self.result_set.ranked_names):
            x = -0.01
            y = (i + 0.5) * 1 / self.result_set.num_players
            ax.annotate(
                n,
                xy=(x, y),
                xycoords=trans,
                clip_on=False,
                va="center",
                ha="right",
                fontsize=5,
            )
            ticks.append(y)
        ax.set_yticks(ticks)
        ax.tick_params(direction="out")
        ax.set_yticklabels([])
        if logscale:
            ax.set_xscale("log")
        # Lay out this figure rather than pyplot's current one.
        figure.tight_layout()
        return figure

    def save_all_plots(
        self,
        prefix: str = "axelrod",
        title_prefix: str = "axelrod",
        filetype: str = "svg",
        progress_bar: bool = True,
    ) -> None:
        """
        A method to save all plots to file.

        Parameters
        ----------
        prefix : str
            A prefix for the file name. This can include the directory.
            Default: axelrod.
        title_prefix : str
            A prefix for the title of the plots (appears on the graphic).
            Default: axelrod.
        filetype : str
            A string for the filetype to save files to: pdf, png, svg,
            etc...
        progress_bar : bool
            Whether or not to create a progress bar, which is updated after
            each plot is saved and closed when done.
        """
        plots = [
            ("boxplot", "Payoff"),
            ("payoff", "Payoff"),
            ("winplot", "Wins"),
            ("sdvplot", "Payoff differences"),
            ("pdplot", "Payoff differences"),
            ("lengthplot", "Length of Matches"),
        ]

        pbar = None
        if progress_bar:
            total = len(plots)  # Total number of plots
            pbar = tqdm.tqdm(total=total, desc="Obtaining plots")

        try:
            for method, name in plots:
                f = getattr(self, method)(
                    title="{} - {}".format(title_prefix, name)
                )
                path = pathlib.Path("{}_{}.{}".format(prefix, method, filetype))
                try:
                    f.savefig(axl_filename(path))
                finally:
                    # Release the figure even if saving fails.
                    plt.close(f)
                if pbar is not None:
                    pbar.update()
        finally:
            # Close the bar so tqdm releases its output line.
            if pbar is not None:
                pbar.close()