-
Notifications
You must be signed in to change notification settings - Fork 268
/
Copy pathmatch_generator.py
150 lines (130 loc) · 4.41 KB
/
match_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from dataclasses import dataclass
from typing import Any, Dict, Iterator, Tuple
from axelrod.random_ import BulkRandomGenerator
@dataclass
class MatchChunk(object):
    """Everything needed to play all repetitions of a single pairing.

    Instances are produced by ``MatchGenerator.build_match_chunks``: one chunk
    per edge (pair of player indices) of the tournament graph.
    """

    # Indices of the two players (into the generator's player list).
    index_pair: Tuple[int, int]
    # Match parameters: "turns", "game", "noise", "prob_end" and
    # "match_attributes" (see MatchGenerator.build_single_match_params).
    match_params: Dict[str, Any]
    # How many times this pairing is to be played.
    repetitions: int
    # Value obtained via next() on the generator's BulkRandomGenerator.
    # NOTE(review): presumably an integer seed rather than a generator
    # instance, despite the annotation -- confirm against axelrod.random_.
    seed: BulkRandomGenerator

    def as_tuple(self) -> Tuple:
        """Return ``(index_pair, match_params, repetitions, seed)``.

        Kept for legacy callers that still consume chunks as plain tuples.
        """
        return (self.index_pair, self.match_params, self.repetitions, self.seed)
class MatchGenerator(object):
    def __init__(
        self,
        players,
        repetitions,
        turns=None,
        game=None,
        noise=0,
        prob_end=None,
        edges=None,
        match_attributes=None,
        seed=None,
    ):
        """
        A class to generate matches. This is used by the Tournament class which
        is in charge of playing the matches and collecting the results.

        Parameters
        ----------
        players : list
            A list of axelrod.Player objects
        repetitions : int
            The number of repetitions of a given match
        turns : integer
            The number of turns per match
        game : axelrod.Game
            The game object used to score the match
        noise : float, 0
            The probability that a player's intended action should be flipped
        prob_end : float
            The probability of a given turn ending a match
        edges : list
            A list of (player index, player index) edges between players. If
            None, every pair of players, including each player paired with
            itself, is matched.
        match_attributes : dict
            Mapping attribute names to values which should be passed to players.
            The default is to use the correct values for turns, game and noise
            but these can be overridden if desired.
        seed : int
            Seed for the BulkRandomGenerator from which a per-match seed is
            drawn for every chunk produced by ``build_match_chunks``.

        Raises
        ------
        ValueError
            If ``edges`` is given and does not cover exactly the player
            indices ``0 .. len(players) - 1``.
        """
        self.players = players
        self.turns = turns
        self.game = game
        self.repetitions = repetitions
        self.noise = noise
        # Not read within this class; kept as a public attribute.
        self.opponents = players
        self.prob_end = prob_end
        self.match_attributes = match_attributes
        self.random_generator = BulkRandomGenerator(seed)
        self.edges = edges
        if edges is not None:
            if not graph_is_connected(edges, players):
                raise ValueError("The graph edges do not include all players.")
            self.size = len(edges)
        else:
            n = len(self.players)
            # One match per unordered pair of distinct players plus one
            # self-match per player: exactly the pairs complete_graph yields.
            self.size = n * (n - 1) // 2 + n

    def __len__(self):
        """Return the number of match chunks ``build_match_chunks`` yields."""
        return self.size

    def build_match_chunks(self) -> Iterator[MatchChunk]:
        """
        A generator that yields one MatchChunk per pairing for a round robin
        tournament (or per edge, when ``edges`` was supplied).

        Each chunk gets a fresh seed drawn from ``self.random_generator``, so
        the sequence of seeds is determined by the ``seed`` given at
        construction.

        Yields
        ------
        MatchChunk
            With ``index_pair=(player1 index, player2 index)``, the shared
            match parameters, the number of repetitions and a per-match seed.
        """
        if self.edges is None:
            edges = complete_graph(self.players)
        else:
            edges = self.edges
        for index_pair in edges:
            match_params = self.build_single_match_params()
            r = next(self.random_generator)
            yield MatchChunk(
                index_pair=index_pair,
                match_params=match_params,
                repetitions=self.repetitions,
                seed=r,
            )

    def build_single_match_params(self):
        """
        Creates a single set of match parameters.

        A new dict is returned on every call, but the ``game`` and
        ``match_attributes`` values are the same objects held by this
        generator (not copies).
        """
        return {
            "turns": self.turns,
            "game": self.game,
            "noise": self.noise,
            "prob_end": self.prob_end,
            "match_attributes": self.match_attributes,
        }
def complete_graph(players):
    """
    Lazily yield every index pair ``(i, j)`` with ``i <= j`` over the players.

    Pairs come out in lexicographic order and include the self-pairs
    ``(i, i)``, so ``n`` players produce ``n * (n + 1) // 2`` edges.
    """
    total = len(players)
    for first in range(total):
        yield from ((first, second) for second in range(first, total))
def graph_is_connected(edges, players):
    """
    Check that every player index is an endpoint of at least one edge.

    This is a coverage test, not a true connectivity test: it does not check
    that every node is reachable from every other node.

    Parameters:
    -----------
    edges : an iterable of index tuples (typically 2-tuples)
    players : any sized collection; only ``len(players)`` is used

    Returns:
    --------
    boolean : True exactly when the set of indices appearing in the edges
        equals ``{0, ..., len(players) - 1}``. An edge referencing an index
        outside that range therefore makes the result False.
    """
    expected = set(range(len(players)))
    seen = {endpoint for edge in edges for endpoint in edge}
    return expected == seen