# Patch summary: replace append-loops and if/else assignments with
# comprehensions, generator expressions and a conditional expression.
#
#   consts.py           building `disk_fours` becomes a list comprehension over
#                       FOURS; the count is now read via
#                       len(DISK_FOURS[row, column]) instead of len(disk_fours).
#   network.py          `assign` returns a list comprehension of
#                       tf.assign(other_var, self_var) instead of filling
#                       `copy_ops` in a loop.
#   policy_training.py  `average_moves` passes a generator expression to sum();
#                       `policy_player_score` is set with a conditional
#                       expression instead of an if/else statement.
#   util.py             `find_previous_run` passes a generator expression to
#                       max() instead of a list.
#
# NOTE(review): in this copy every hunk's lines are joined by single spaces
#   onto one physical line (the first hunk header declares 13 old / 9 new
#   lines), so indentation and blank context lines are not recoverable here.
#   Re-export the patch from the repository before trying to apply it.
# NOTE(review): consts.py - the lone `-` before the new DISK_FOUR_COUNTS line
#   appears to delete the blank line that separated the loop from `# Results`;
#   confirm that is intended.
# NOTE(review): consts.py - len(DISK_FOURS[row, column]) equals
#   len(disk_fours) only if indexing DISK_FOURS returns the stored list
#   unchanged; DISK_FOURS's definition is outside this patch - TODO confirm.
# NOTE(review): policy_training.py - the new `average_moves` line is 79
#   characters before indentation and the new `policy_player_score` line is 76,
#   so both likely exceed a 79-column limit once indented; check the project's
#   line-length setting.
# NOTE(review): policy_training.py - `average_moves` divides by
#   self.config.batch_size in both the old and new lines; presumably
#   len(games) equals the batch size - verify against the caller.
diff --git a/consts.py b/consts.py index 27596c1..0b342e1 100644 --- a/consts.py +++ b/consts.py @@ -50,13 +50,9 @@ DISK_FOUR_COUNTS = np.zeros([HEIGHT, WIDTH], int) for row in range(HEIGHT): for column in range(WIDTH): - disk_fours = [] - for four in FOURS: - if four[row, column]: - disk_fours.append(four) + disk_fours = [four for four in FOURS if four[row, column]] DISK_FOURS[row, column] = disk_fours - DISK_FOUR_COUNTS[row, column] = len(disk_fours) - + DISK_FOUR_COUNTS[row, column] = len(DISK_FOURS[row, column]) # Results RED_WIN = 1 DRAW = 0 diff --git a/network.py b/network.py index baa532b..4fa6514 100644 --- a/network.py +++ b/network.py @@ -106,10 +106,10 @@ def variables(self): self.scope + '/') def assign(self, other): - copy_ops = [] - for self_var, other_var in zip(self.variables, other.variables): - copy_ops.append(tf.assign(other_var, self_var)) - return copy_ops + return [ + tf.assign(other_var, self_var) + for self_var, other_var in zip(self.variables, other.variables) + ] class PolicyNetwork(BaseNetwork): diff --git a/policy_training.py b/policy_training.py index 6938252..b8e74fc 100644 --- a/policy_training.py +++ b/policy_training.py @@ -141,8 +141,7 @@ def train_games(self, opponent, games): def process_results(self, opponent, games, step, summary): win_rate = np.mean([game.policy_player_score for game in games]) - average_moves = sum([len(game.moves) - for game in games]) / self.config.batch_size + average_moves = sum(len(game.moves) for game in games) / self.config.batch_size opponent_summary = tf.Summary() opponent_summary.value.add( @@ -263,10 +262,7 @@ def move(self, move, policy_player_turn=False): self.positions.append(self.position) if self.position.gameover(): self.result = self.position.result - if self.result: - self.policy_player_score = float(policy_player_turn) - else: - self.policy_player_score = 0.5 + self.policy_player_score = float(policy_player_turn) if self.result else 0.5 def main(_): diff --git a/util.py b/util.py index 
0957892..84b17f5 100644 --- a/util.py +++ b/util.py @@ -10,7 +10,7 @@ def find_previous_run(dir): if os.path.isdir(dir): runs = [child[4:] for child in os.listdir(dir) if child[:4] == 'run_'] if runs: - return max([int(run) for run in runs]) + return max(int(run) for run in runs) return 0