diff --git a/FAQ.md b/FAQ.md
index 9c965de03..c5b50d5a1 100644
--- a/FAQ.md
+++ b/FAQ.md
@@ -1,4 +1,4 @@
-# Leela Zero常见问题解答 #
+# Leela Zero常见问题解答 #
# Frequently Asked Questions about Leela Zero #
## 为什么网络不是每次都变强的 ##
diff --git a/msvc/VS2017/leela-zero.vcxproj b/msvc/VS2017/leela-zero.vcxproj
index b69f98222..ff03b1f2b 100644
--- a/msvc/VS2017/leela-zero.vcxproj
+++ b/msvc/VS2017/leela-zero.vcxproj
@@ -28,6 +28,7 @@
+    <ClCompile Include="..\..\src\Ladder.cpp" />
@@ -57,6 +58,7 @@
+    <ClInclude Include="..\..\src\Ladder.h" />
@@ -83,7 +85,7 @@
    <ProjectGuid>{7B887BFE-8D2C-46CD-B139-5213434BF218}</ProjectGuid>
    <Keyword>Win32Proj</Keyword>
    <RootNamespace>leelazero</RootNamespace>
-    <WindowsTargetPlatformVersion>10.0.16299.0</WindowsTargetPlatformVersion>
+    <WindowsTargetPlatformVersion>10.0.17134.0</WindowsTargetPlatformVersion>
@@ -224,4 +226,4 @@
-</Project>
+</Project>
\ No newline at end of file
diff --git a/msvc/VS2017/leela-zero.vcxproj.filters b/msvc/VS2017/leela-zero.vcxproj.filters
index 5ac9f8662..f8d1c8884 100644
--- a/msvc/VS2017/leela-zero.vcxproj.filters
+++ b/msvc/VS2017/leela-zero.vcxproj.filters
@@ -102,6 +102,9 @@
      <Filter>Header Files</Filter>
+    <ClInclude Include="..\..\src\Ladder.h">
+      <Filter>Header Files</Filter>
+    </ClInclude>
@@ -182,8 +185,11 @@
      <Filter>Source Files</Filter>
+    <ClCompile Include="..\..\src\Ladder.cpp">
+      <Filter>Source Files</Filter>
+    </ClCompile>
-</Project>
+</Project>
\ No newline at end of file
diff --git a/src/CPUPipe.cpp b/src/CPUPipe.cpp
index 624180f54..250f6d084 100644
--- a/src/CPUPipe.cpp
+++ b/src/CPUPipe.cpp
@@ -266,29 +266,104 @@ void batchnorm(const size_t channels,
std::vector<float>& data,
const float* const means,
const float* const stddivs,
- const float* const eltwise = nullptr) {
- const auto lambda_ReLU = [](const auto val) { return (val > 0.0f) ?
- val : 0.0f; };
+ const float* const prelu_alphas,
+ const bool relu = true,
+ const float* const eltwise = nullptr)
+{
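+ // PReLU: pass positive activations through unchanged, scale negative
+ // ones by a learned per-channel slope (alpha).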
+ const auto lambda_PReLU = [](const auto val, const auto alpha) { return (val > 0.0f) ?
+ val : alpha * val; };
for (auto c = size_t{0}; c < channels; ++c) {
const auto mean = means[c];
const auto scale_stddiv = stddivs[c];
const auto arr = &data[c * spatial_size];
+ const auto prelu_alpha = prelu_alphas[c];
if (eltwise == nullptr) {
// Classical BN
for (auto b = size_t{0}; b < spatial_size; b++) {
- arr[b] = lambda_ReLU(scale_stddiv * (arr[b] - mean));
+ auto val = scale_stddiv * (arr[b] - mean);
+ if (relu) {
+ val = lambda_PReLU(val, prelu_alpha);
+ }
+ arr[b] = val;
}
} else {
// BN + residual add
const auto res = &eltwise[c * spatial_size];
for (auto b = size_t{0}; b < spatial_size; b++) {
- arr[b] = lambda_ReLU((scale_stddiv * (arr[b] - mean)) + res[b]);
+ auto val = scale_stddiv * (arr[b] - mean) + res[b];
+ if (relu) {
+ val = lambda_PReLU(val, prelu_alpha);
+ }
+ arr[b] = val;
}
}
}
}
+std::vector<float> innerproduct(const size_t inputs,
+ const size_t outputs,
+ const bool ReLU,
+ const std::vector<float>& input,
+ const std::vector<float>& weights,
+ const std::vector<float>& biases) {
+ std::vector<float> output(outputs);
+
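+ // output = weights * input: row-major (outputs x inputs) matrix-vector product.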
+ cblas_sgemv(CblasRowMajor, CblasNoTrans,
+ // M K
+ outputs, inputs,
+ 1.0f, &weights[0], inputs,
+ &input[0], 1,
+ 0.0f, &output[0], 1);
+
+ const auto lambda_ReLU = [](const auto val) { return (val > 0.0f) ?
+ val : 0.0f; };
+ for (unsigned int o = 0; o < outputs; o++) {
+ auto val = biases[o] + output[o];
+ if (ReLU) {
+ val = lambda_ReLU(val);
+ }
+ output[o] = val;
+ }
+
+ return output;
+}
+
+void global_avg_pooling(const size_t channels,
+ const std::vector<float>& input,
+ std::vector<float>& output) {
+
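+ // Squeeze step: reduce each channel to the mean of its BOARD_SQUARES activations.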
+ for (auto c = size_t{0}; c < channels; c++) {
+ auto acc = 0.0f;
+ for (auto i = size_t{0}; i < BOARD_SQUARES; i++) {
+ acc += input[c * BOARD_SQUARES + i];
+ }
+ output[c] = acc / BOARD_SQUARES;
+ }
+}
+
+void apply_se(const size_t channels,
+ const std::vector<float>& input,
+ const std::vector<float>& res,
+ const std::vector<float>& scale,
+ std::vector<float>& output,
+ const std::vector<float>& prelu_alphas) {
+
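+ // Excitation: gate each channel by sigmoid(scale), add the skip connection,
+ // then apply the per-channel PReLU.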
+ const auto lambda_PReLU = [](const auto val, const auto alpha) { return (val > 0.0f) ?
+ val : alpha * val; };
+
+ const auto lambda_sigmoid = [](const auto val) { return 1.0f/(1.0f + std::exp(-val)); };
+
+ for (auto c = size_t{0}; c < channels; c++) {
+ auto sig_scale = lambda_sigmoid(scale[c]);
+ auto alpha = prelu_alphas[c];
+ for (auto i = size_t{0}; i < BOARD_SQUARES; i++) {
+ output[c * BOARD_SQUARES + i] = lambda_PReLU(sig_scale * input[c * BOARD_SQUARES + i]
+ + res[c * BOARD_SQUARES + i], alpha);
+ }
+ }
+}
+
void CPUPipe::forward(const std::vector<float>& input,
std::vector<float>& output_pol,
std::vector<float>& output_val) {
@@ -309,19 +384,24 @@ void CPUPipe::forward(const std::vector<float>& input,
winograd_convolve3(output_channels, input, m_conv_weights[0], V, M, conv_out);
batchnorm(output_channels, conv_out,
m_batchnorm_means[0].data(),
- m_batchnorm_stddivs[0].data());
+ m_batchnorm_stddivs[0].data(),
+ m_prelu_alphas[0].data());
+
// Residual tower
+ auto pooling = std::vector<float>(output_channels);
auto conv_in = std::vector<float>(output_channels * BOARD_SQUARES);
auto res = std::vector<float>(output_channels * BOARD_SQUARES);
+ auto block = 0;
for (auto i = size_t{1}; i < m_conv_weights.size(); i += 2) {
- auto output_channels = m_input_channels;
+ auto output_channels = m_batchnorm_stddivs[i].size();
std::swap(conv_out, conv_in);
winograd_convolve3(output_channels, conv_in,
m_conv_weights[i], V, M, conv_out);
batchnorm(output_channels, conv_out,
m_batchnorm_means[i].data(),
- m_batchnorm_stddivs[i].data());
+ m_batchnorm_stddivs[i].data(),
+ m_prelu_alphas[i].data());
std::swap(conv_in, res);
std::swap(conv_out, conv_in);
@@ -330,7 +410,20 @@ void CPUPipe::forward(const std::vector& input,
batchnorm(output_channels, conv_out,
m_batchnorm_means[i + 1].data(),
m_batchnorm_stddivs[i + 1].data(),
- res.data());
+ m_prelu_alphas[i + 1].data(),
+ false);
+ std::swap(conv_out, conv_in);
+
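+ // Squeeze-and-excitation: pool each channel to a scalar, run the two
+ // small FC layers, then rescale the block output before the skip add.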
+ global_avg_pooling(output_channels, conv_in, pooling);
+
+ auto fc_outputs = m_se_fc1_w[block].size() / output_channels;
+ auto se1 = innerproduct(output_channels, fc_outputs, true, pooling, m_se_fc1_w[block], m_se_fc1_b[block]);
+ auto se2 = innerproduct(fc_outputs, output_channels, false, se1, m_se_fc2_w[block], m_se_fc2_b[block]);
+
+ apply_se(output_channels, conv_in, res, se2, conv_out, m_prelu_alphas[i + 1]);
+
+ block++;
+
}
convolve<1>(Network::OUTPUTS_POLICY, conv_out, m_conv_pol_w, m_conv_pol_b, output_pol);
convolve<1>(Network::OUTPUTS_VALUE, conv_out, m_conv_val_w, m_conv_val_b, output_val);
@@ -342,28 +435,44 @@ void CPUPipe::push_input_convolution(unsigned int /*filter_size*/,
unsigned int /*outputs*/,
const std::vector<float>& weights,
const std::vector<float>& means,
- const std::vector<float>& variances) {
+ const std::vector<float>& variances,
+ const std::vector<float>& prelu_alphas) {
m_conv_weights.push_back(weights);
m_batchnorm_means.push_back(means);
m_batchnorm_stddivs.push_back(variances);
+ m_prelu_alphas.push_back(prelu_alphas);
}
void CPUPipe::push_residual(unsigned int /*filter_size*/,
unsigned int /*channels*/,
unsigned int /*outputs*/,
+ unsigned int /*se_fc_outputs*/,
const std::vector<float>& weights_1,
const std::vector<float>& means_1,
const std::vector<float>& variances_1,
+ const std::vector<float>& prelu_alphas_1,
const std::vector<float>& weights_2,
const std::vector<float>& means_2,
- const std::vector<float>& variances_2) {
+ const std::vector<float>& variances_2,
+ const std::vector<float>& prelu_alphas_2,
+ const std::vector<float>& se_fc1_w,
+ const std::vector<float>& se_fc1_b,
+ const std::vector<float>& se_fc2_w,
+ const std::vector<float>& se_fc2_b) {
m_conv_weights.push_back(weights_1);
m_batchnorm_means.push_back(means_1);
m_batchnorm_stddivs.push_back(variances_1);
+ m_prelu_alphas.push_back(prelu_alphas_1);
m_conv_weights.push_back(weights_2);
m_batchnorm_means.push_back(means_2);
m_batchnorm_stddivs.push_back(variances_2);
+ m_prelu_alphas.push_back(prelu_alphas_2);
+
+ m_se_fc1_w.push_back(se_fc1_w);
+ m_se_fc1_b.push_back(se_fc1_b);
+ m_se_fc2_w.push_back(se_fc2_w);
+ m_se_fc2_b.push_back(se_fc2_b);
}
void CPUPipe::push_convolve(unsigned int filter_size,
diff --git a/src/CPUPipe.h b/src/CPUPipe.h
index 2dd6b552e..3576357f3 100644
--- a/src/CPUPipe.h
+++ b/src/CPUPipe.h
@@ -37,17 +37,25 @@ class CPUPipe : public ForwardPipe {
unsigned int outputs,
const std::vector<float>& weights,
const std::vector<float>& means,
- const std::vector<float>& variances);
+ const std::vector<float>& variances,
+ const std::vector<float>& prelu_alphas);
virtual void push_residual(unsigned int filter_size,
unsigned int channels,
unsigned int outputs,
+ unsigned int se_fc_outputs,
const std::vector<float>& weights_1,
const std::vector<float>& means_1,
const std::vector<float>& variances_1,
+ const std::vector<float>& prelu_alphas_1,
const std::vector<float>& weights_2,
const std::vector<float>& means_2,
- const std::vector<float>& variances_2);
+ const std::vector<float>& variances_2,
+ const std::vector<float>& prelu_alphas_2,
+ const std::vector<float>& se_fc1_w,
+ const std::vector<float>& se_fc1_b,
+ const std::vector<float>& se_fc2_w,
+ const std::vector<float>& se_fc2_b);
virtual void push_convolve(unsigned int filter_size,
unsigned int channels,
@@ -83,6 +91,12 @@ class CPUPipe : public ForwardPipe {
std::vector<std::vector<float>> m_conv_weights;
std::vector<std::vector<float>> m_batchnorm_means;
std::vector<std::vector<float>> m_batchnorm_stddivs;
+ std::vector<std::vector<float>> m_prelu_alphas;
+
+ std::vector<std::vector<float>> m_se_fc1_w;
+ std::vector<std::vector<float>> m_se_fc1_b;
+ std::vector<std::vector<float>> m_se_fc2_w;
+ std::vector<std::vector<float>> m_se_fc2_b;
std::vector<float> m_conv_pol_w;
std::vector<float> m_conv_val_w;
diff --git a/src/FastBoard.cpp b/src/FastBoard.cpp
index a9aaff00f..6d5fb9b14 100644
--- a/src/FastBoard.cpp
+++ b/src/FastBoard.cpp
@@ -24,6 +24,7 @@
#include
#include
#include
+#include
#include "Utils.h"
#include "config.h"
@@ -81,6 +82,13 @@ FastBoard::square_t FastBoard::get_square(int vertex) const {
return m_square[vertex];
}
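+// Liberties are tracked per chain: follow the union-find parent to the
+// chain root and read its liberty count.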
+unsigned short FastBoard::get_liberties(int vertex) const {
+ assert(vertex >= 0 && vertex < MAXSQ);
+ assert(vertex >= 0 && vertex < m_maxsq);
+
+ return m_libs[m_parent[vertex]];
+}
+
void FastBoard::set_square(int vertex, FastBoard::square_t content) {
assert(vertex >= 0 && vertex < MAXSQ);
assert(vertex >= 0 && vertex < m_maxsq);
@@ -93,6 +101,17 @@ FastBoard::square_t FastBoard::get_square(int x, int y) const {
return get_square(get_vertex(x, y));
}
+
+int FastBoard::get_square_neighbor(int vertex, int dir) const {
+ assert(0 <= dir && dir <= 3);
+
+ return vertex + m_dirs[dir];
+}
+
+unsigned short FastBoard::get_liberties(int x, int y) const {
+ return get_liberties(get_vertex(x, y));
+}
+
void FastBoard::set_square(int x, int y, FastBoard::square_t content) {
set_square(get_vertex(x, y), content);
}
@@ -311,6 +330,39 @@ void FastBoard::display_board(int lastmove) {
myprintf("\n");
}
+
+void FastBoard::display_liberties(int lastmove) {
+ int boardsize = get_boardsize();
+
+ myprintf(" ");
+ print_columns();
+ for (int j = boardsize-1; j >= 0; j--) {
+ myprintf("%2d", j+1);
+ if (lastmove == get_vertex(0, j))
+ myprintf("(");
+ else
+ myprintf(" ");
+ for (int i = 0; i < boardsize; i++) {
+ if (get_square(i,j) != EMPTY) {
+ int libs = get_liberties(i, j);
+ if (libs > 9) { libs = 9; }
+ myprintf("%1d", libs);
+ } else if (starpoint(boardsize, i, j)) {
+ myprintf("+");
+ } else {
+ myprintf(".");
+ }
+ if (lastmove == get_vertex(i, j)) myprintf(")");
+ else if (i != boardsize-1 && lastmove == get_vertex(i, j)+1) myprintf("(");
+ else myprintf(" ");
+ }
+ myprintf("%2d\n", j+1);
+ }
+ myprintf(" ");
+ print_columns();
+ myprintf("\n");
+}
+
void FastBoard::print_columns() {
for (int i = 0; i < get_boardsize(); i++) {
if (i < 25) {
@@ -512,6 +564,13 @@ int FastBoard::get_to_move() const {
return m_tomove;
}
+int FastBoard::get_not_to_move() const {
+ if (black_to_move()) {
+ return WHITE;
+ }
+ return BLACK;
+}
+
bool FastBoard::black_to_move() const {
return m_tomove == BLACK;
}
diff --git a/src/FastBoard.h b/src/FastBoard.h
index 0454957f8..ff5199d27 100644
--- a/src/FastBoard.h
+++ b/src/FastBoard.h
@@ -29,6 +29,7 @@
class FastBoard {
friend class FastState;
+ friend class Ladder;
public:
/*
neighbor counts are up to 4, so 3 bits is ok,
@@ -66,6 +67,8 @@ class FastBoard {
square_t get_square(int x, int y) const;
square_t get_square(int vertex) const ;
int get_vertex(int x, int y) const;
+ unsigned short get_liberties(int x, int y) const;
+ unsigned short get_liberties(int vertex) const ;
void set_square(int x, int y, square_t content);
void set_square(int vertex, square_t content);
std::pair<int, int> get_xy(int vertex) const;
@@ -80,6 +83,7 @@ class FastBoard {
bool black_to_move() const;
bool white_to_move() const;
int get_to_move() const;
+ int get_not_to_move() const;
void set_to_move(int color);
std::string move_to_text(int move) const;
@@ -89,10 +93,12 @@ class FastBoard {
void reset_board(int size);
void display_board(int lastmove = -1);
+ void display_liberties(int lastmove = -1);
static bool starpoint(int size, int point);
static bool starpoint(int size, int x, int y);
+ int get_square_neighbor(int vertex, int dir) const;
protected:
/*
bit masks to detect eyes on neighbors
diff --git a/src/FastState.cpp b/src/FastState.cpp
index be9189f47..9714845af 100644
--- a/src/FastState.cpp
+++ b/src/FastState.cpp
@@ -27,6 +27,8 @@
#include "Utils.h"
#include "Zobrist.h"
+#include "Ladder.h"
+
using namespace Utils;
void FastState::init_game(int size, float komi) {
@@ -61,7 +63,7 @@ void FastState::reset_board(void) {
board.reset_board(board.get_boardsize());
}
-bool FastState::is_move_legal(int color, int vertex) {
+bool FastState::is_move_legal(int color, int vertex) const {
return vertex == FastBoard::PASS ||
vertex == FastBoard::RESIGN ||
(vertex != m_komove &&
diff --git a/src/FastState.h b/src/FastState.h
index b4e5097b8..880465068 100644
--- a/src/FastState.h
+++ b/src/FastState.h
@@ -33,7 +33,7 @@ class FastState {
void reset_board();
void play_move(int vertex);
- bool is_move_legal(int color, int vertex);
+ bool is_move_legal(int color, int vertex) const;
void set_komi(float komi);
float get_komi() const;
diff --git a/src/ForwardPipe.h b/src/ForwardPipe.h
index b2e042d49..98cefbed8 100644
--- a/src/ForwardPipe.h
+++ b/src/ForwardPipe.h
@@ -37,17 +37,25 @@ class ForwardPipe {
unsigned int outputs,
const std::vector<float>& weights,
const std::vector<float>& means,
- const std::vector<float>& variances) = 0;
+ const std::vector<float>& variances,
+ const std::vector<float>& prelu_alphas) = 0;
virtual void push_residual(unsigned int filter_size,
unsigned int channels,
unsigned int outputs,
+ unsigned int se_fc_outputs,
const std::vector<float>& weights_1,
const std::vector<float>& means_1,
const std::vector<float>& variances_1,
+ const std::vector<float>& prelu_alphas_1,
const std::vector<float>& weights_2,
const std::vector<float>& means_2,
- const std::vector<float>& variances_2) = 0;
+ const std::vector<float>& variances_2,
+ const std::vector<float>& prelu_alphas_2,
+ const std::vector<float>& se_fc1_w,
+ const std::vector<float>& se_fc1_b,
+ const std::vector<float>& se_fc2_w,
+ const std::vector<float>& se_fc2_b) = 0;
virtual void push_convolve(unsigned int filter_size,
unsigned int channels,
diff --git a/src/GTP.cpp b/src/GTP.cpp
index b773f3129..fa6013dfc 100644
--- a/src/GTP.cpp
+++ b/src/GTP.cpp
@@ -156,6 +156,8 @@ const std::string GTP::s_commands[] = {
"play",
"genmove",
"showboard",
+ "showladders",
+ "showliberties",
"undo",
"final_score",
"final_status_list",
@@ -216,6 +218,14 @@ bool GTP::execute(GameState & game, std::string xinput) {
transform_lowercase = false;
}
+ if (xinput.find("add_features") != std::string::npos) {
+ transform_lowercase = false;
+ }
+
+ if (xinput.find("dump_supervised") != std::string::npos) {
+ transform_lowercase = false;
+ }
+
/* eat empty lines, simple preprocessing, lower case */
for (unsigned int tmp = 0; tmp < xinput.size(); tmp++) {
if (xinput[tmp] == 9) {
@@ -501,6 +511,14 @@ bool GTP::execute(GameState & game, std::string xinput) {
gtp_printf(id, "");
game.display_state();
return true;
+ } else if (command.find("showladders") == 0) {
+ gtp_printf(id, "");
+ game.display_ladders();
+ return true;
+ } else if (command.find("showliberties") == 0) {
+ gtp_printf(id, "");
+ game.display_liberties();
+ return true;
} else if (command.find("final_score") == 0) {
float ftmp = game.final_score();
/* white wins */
@@ -883,6 +901,22 @@ bool GTP::execute(GameState & game, std::string xinput) {
gtp_fail_printf(id, "syntax not understood");
}
+ return true;
+ } else if (command.find("add_features") == 0) {
+ std::istringstream cmdstream(command);
+ std::string tmp, sgfname, outname;
+
+ // tmp will eat add_features
+ cmdstream >> tmp >> sgfname >> outname;
+
+ Training::add_features(sgfname, outname);
+
+ if (!cmdstream.fail()) {
+ gtp_printf(id, "");
+ } else {
+ gtp_fail_printf(id, "syntax not understood");
+ }
+
return true;
}
diff --git a/src/GameState.cpp b/src/GameState.cpp
index 9d94927d8..e78a9d049 100644
--- a/src/GameState.cpp
+++ b/src/GameState.cpp
@@ -33,6 +33,7 @@
#include "FullBoard.h"
#include "KoState.h"
#include "UCTSearch.h"
+#include "Ladder.h"
void GameState::init_game(int size, float komi) {
KoState::init_game(size, komi);
@@ -169,6 +170,14 @@ void GameState::display_state() {
m_timecontrol.display_times();
}
+void GameState::display_ladders() {
+ Ladder::display_ladders(*this);
+}
+
+void GameState::display_liberties() {
+ board.display_liberties();
+}
+
int GameState::who_resigned() const {
return m_resigned;
}
diff --git a/src/GameState.h b/src/GameState.h
index a56757504..9f58fc6b1 100644
--- a/src/GameState.h
+++ b/src/GameState.h
@@ -63,6 +63,8 @@ class GameState : public KoState {
void adjust_time(int color, int time, int stones);
void display_state();
+ void display_ladders();
+ void display_liberties();
bool has_resigned() const;
int who_resigned() const;
diff --git a/src/Ladder.cpp b/src/Ladder.cpp
new file mode 100644
index 000000000..16bf78fbe
--- /dev/null
+++ b/src/Ladder.cpp
@@ -0,0 +1,228 @@
+/*
+ This file is part of Leela Zero.
+ Copyright (C) 2018 Henrik Forsten
+
+ Leela Zero is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Zero is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Zero. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#include <algorithm>
+#include <cassert>
+#include <memory>
+
+#include "config.h"
+#include "Ladder.h"
+
+#include "Utils.h"
+
+using namespace Utils;
+
+
+Ladder::LadderStatus Ladder::ladder_status(const FastState &state) {
+ const auto& board = state.board;
+
+ Ladder::LadderStatus status;
+
+ for (auto i = 0; i < BOARD_SIZE; i++) {
+ for (auto j = 0; j < BOARD_SIZE; j++) {
+ auto vertex = board.get_vertex(i, j);
+ status[i][j] = NO_LADDER;
+ if (ladder_capture(state, vertex)) {
+ status[i][j] = CAPTURE;
+ }
+ if (ladder_escape(state, vertex)) {
+ status[i][j] = ESCAPE;
+ }
+ }
+ }
+ return status;
+}
+
+bool Ladder::ladder_capture(const FastState &state, int vertex, int group, int depth) {
+
+ const auto &board = state.board;
+ const auto capture_player = board.get_to_move();
+ const auto escape_player = board.get_not_to_move();
+
+ if (!state.is_move_legal(capture_player, vertex)) {
+ return false;
+ }
+
+ // Assume that capture succeeds if it takes this long
+ if (depth >= 100) {
+ return true;
+ }
+
+ std::vector<int> groups_in_ladder;
+
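+ // A candidate ladder starts by playing at vertex to reduce an adjacent
+ // opponent group from two liberties to one (atari).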
+ if (group == FastBoard::PASS) {
+ // Check if there are nearby groups with 2 liberties
+ for (int d = 0; d < 4; d++) {
+ int n_vtx = board.get_square_neighbor(vertex, d);
+ int n = board.get_square(n_vtx);
+ if ((n == escape_player) && (board.get_liberties(n_vtx) == 2)) {
+ auto parent = board.m_parent[n_vtx];
+ if (std::find(groups_in_ladder.begin(), groups_in_ladder.end(), parent) == groups_in_ladder.end()) {
+ groups_in_ladder.emplace_back(parent);
+ }
+ }
+ }
+ } else {
+ groups_in_ladder.emplace_back(group);
+ }
+
+ for (auto& group : groups_in_ladder) {
+ auto state_copy = std::make_unique<FastState>(state);
+ auto &board_copy = state_copy->board;
+
+ state_copy->play_move(vertex);
+
+ int escape = FastBoard::PASS;
+ int newpos = group;
+ do {
+ for (int d = 0; d < 4; d++) {
+ int stone = newpos + board_copy.m_dirs[d];
+ // If the surrounding stones are in atari capture fails
+ if (board_copy.m_square[stone] == capture_player) {
+ if (board_copy.get_liberties(stone) == 1) {
+ return false;
+ }
+ }
+ // Possible move to escape
+ if (board_copy.m_square[stone] == FastBoard::EMPTY) {
+ escape = stone;
+ }
+ }
+ newpos = board_copy.m_next[newpos];
+ } while (newpos != group);
+
+ assert(escape != FastBoard::PASS);
+
+ // If escaping fails the capture was successful
+ if (!ladder_escape(*state_copy, escape, group, depth + 1)) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+bool Ladder::ladder_escape(const FastState &state, const int vertex, int group, int depth) {
+ const auto &board = state.board;
+ const auto escape_player = board.get_to_move();
+
+ if (!state.is_move_legal(escape_player, vertex)) {
+ return false;
+ }
+
+ // Assume that escaping failed if it takes this long
+ if (depth >= 100) {
+ return false;
+ }
+
+ std::vector<int> groups_in_ladder;
+
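+ // Collect the friendly group(s) in atari next to vertex that this move
+ // could try to save.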
+ if (group == FastBoard::PASS) {
+ // Check if there are nearby groups with 1 liberty
+ for (int d = 0; d < 4; d++) {
+ int n_vtx = board.get_square_neighbor(vertex, d);
+ int n = board.get_square(n_vtx);
+ if ((n == escape_player) && (board.get_liberties(n_vtx) == 1)) {
+ auto parent = board.m_parent[n_vtx];
+ if (std::find(groups_in_ladder.begin(), groups_in_ladder.end(), parent) == groups_in_ladder.end()) {
+ groups_in_ladder.emplace_back(parent);
+ }
+ }
+ }
+ } else {
+ groups_in_ladder.emplace_back(group);
+ }
+
+ for (auto& group : groups_in_ladder) {
+ auto state_copy = std::make_unique<FastState>(state);
+ auto &board_copy = state_copy->board;
+
+ state_copy->play_move(vertex);
+
+ if (board_copy.get_liberties(group) >= 3) {
+ // Opponent can't atari on the next turn
+ return true;
+ }
+
+ if (board_copy.get_liberties(group) == 1) {
+ // Will get captured on the next turn
+ return false;
+ }
+
+ // Still two liberties left, check for possible captures
+ int newpos = group;
+ do {
+ for (int d = 0; d < 4; d++) {
+ int empty = newpos + board_copy.m_dirs[d];
+ if (board_copy.m_square[empty] == FastBoard::EMPTY) {
+ if (ladder_capture(*state_copy, empty, group, depth + 1)) {
+ // Got captured
+ return false;
+ }
+ }
+ }
+ newpos = board_copy.m_next[newpos];
+ } while (newpos != group);
+
+ // Ladder capture failed, escape succeeded
+ return true;
+ }
+
+ return false;
+}
+
+static void print_columns() {
+ for (int i = 0; i < BOARD_SIZE; i++) {
+ if (i < 25) {
+ myprintf("%c ", (('a' + i < 'i') ? 'a' + i : 'a' + i + 1));
+ }
+ else {
+ myprintf("%c ", (('A' + (i - 25) < 'I') ? 'A' + (i - 25) : 'A' + (i - 25) + 1));
+ }
+ }
+ myprintf("\n");
+}
+
+void Ladder::display_ladders(const LadderStatus &status) {
+ myprintf("\n ");
+ print_columns();
+ for (int j = BOARD_SIZE-1; j >= 0; j--) {
+ myprintf("%2d", j+1);
+ myprintf(" ");
+ for (int i = 0; i < BOARD_SIZE; i++) {
+ if (status[i][j] == CAPTURE) {
+ myprintf("C");
+ } else if (status[i][j] == ESCAPE) {
+ myprintf("E");
+ } else if (FastBoard::starpoint(BOARD_SIZE, i, j)) {
+ myprintf("+");
+ } else {
+ myprintf(".");
+ }
+ myprintf(" ");
+ }
+ myprintf("%2d\n", j+1);
+ }
+ myprintf(" ");
+ print_columns();
+ myprintf("\n");
+}
+
+void Ladder::display_ladders(const FastState &state) {
+ display_ladders(ladder_status(state));
+}
diff --git a/src/Ladder.h b/src/Ladder.h
new file mode 100644
index 000000000..38f07a66d
--- /dev/null
+++ b/src/Ladder.h
@@ -0,0 +1,45 @@
+/*
+ This file is part of Leela Zero.
+ Copyright (C) 2018 Henrik Forsten
+
+ Leela Zero is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Zero is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Zero. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#ifndef LADDER_H_INCLUDED
+#define LADDER_H_INCLUDED
+
+#include "FastState.h"
+
+void ladder_captures(const FastState &state);
+bool ladder_capture(const FastState &state, const int vertex);
+
+class Ladder {
+
+ enum ladder_status_t : char {
+ NO_LADDER = 0, CAPTURE = 1, ESCAPE = 2
+ };
+
+ using LadderStatus = std::array<std::array<ladder_status_t, BOARD_SIZE>, BOARD_SIZE>;
+
+ public:
+ static LadderStatus ladder_status(const FastState &state);
+
+ static bool ladder_capture(const FastState &state, int vertex, int group=FastBoard::PASS, int depth=0);
+ static bool ladder_escape(const FastState &state, int vertex, int group=FastBoard::PASS, int depth=0);
+
+ static void display_ladders(const LadderStatus &status);
+ static void display_ladders(const FastState &state);
+};
+
+#endif
diff --git a/src/Makefile b/src/Makefile
index 9cf62a9ef..c15402b44 100644
--- a/src/Makefile
+++ b/src/Makefile
@@ -50,7 +50,8 @@ sources = Network.cpp FullBoard.cpp KoState.cpp Training.cpp \
SGFParser.cpp Timing.cpp Utils.cpp FastBoard.cpp \
SGFTree.cpp Zobrist.cpp FastState.cpp GTP.cpp Random.cpp \
SMP.cpp UCTNode.cpp UCTNodePointer.cpp UCTNodeRoot.cpp \
- OpenCL.cpp OpenCLScheduler.cpp NNCache.cpp Tuner.cpp CPUPipe.cpp
+ OpenCL.cpp OpenCLScheduler.cpp NNCache.cpp Tuner.cpp CPUPipe.cpp \
+ Ladder.cpp
objects = $(sources:.cpp=.o)
deps = $(sources:%.cpp=%.d)
diff --git a/src/Network.cpp b/src/Network.cpp
index b979d1f64..573206d7d 100644
--- a/src/Network.cpp
+++ b/src/Network.cpp
@@ -57,6 +57,7 @@
#include "Random.h"
#include "ThreadPool.h"
#include "Timing.h"
+#include "Ladder.h"
#include "Utils.h"
namespace x3 = boost::spirit::x3;
@@ -198,14 +199,14 @@ std::pair<int, int> Network::load_v1_network(std::istream& wtfile) {
}
linecount++;
}
- // 1 format id, 1 input layer (4 x weights), 14 ending weights,
- // the rest are residuals, every residual has 8 x weight lines
- auto residual_blocks = linecount - (1 + 4 + 14);
- if (residual_blocks % 8 != 0) {
+ // 1 format id, 1 input layer (6 x weights), 16 ending weights,
+ // the rest are residuals, every residual has 16 x weight lines
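+ // Per residual block: conv1, gamma1, beta1, mean1, var1, alpha1,
+ // conv2, gamma2, beta2, mean2, var2, se_fc1_w, se_fc1_b, se_fc2_w,
+ // se_fc2_b, alpha2.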
+ auto residual_blocks = linecount - (1 + 6 + 16);
+ if (residual_blocks % 16 != 0) {
myprintf("\nInconsistent number of weights in the file.\n");
return {0, 0};
}
- residual_blocks /= 8;
+ residual_blocks /= 16;
myprintf("%d blocks.\n", residual_blocks);
// Re-read file and process
@@ -215,8 +216,8 @@ std::pair Network::load_v1_network(std::istream& wtfile) {
// Get the file format id out of the way
std::getline(wtfile, line);
- const auto plain_conv_layers = 1 + (residual_blocks * 2);
- const auto plain_conv_wts = plain_conv_layers * 4;
+ const auto input_weights = 6;
+ const auto plain_conv_wts = input_weights + residual_blocks * 16;
linecount = 0;
while (std::getline(wtfile, line)) {
std::vector<float> weights;
@@ -228,41 +229,124 @@ std::pair Network::load_v1_network(std::istream& wtfile) {
linecount + 2); //+1 from version line, +1 from 0-indexing
return {0,0};
}
- if (linecount < plain_conv_wts) {
- if (linecount % 4 == 0) {
+ if (linecount < input_weights) {
+ if (linecount % 6 == 0) {
m_conv_weights.emplace_back(weights);
- } else if (linecount % 4 == 1) {
- // Redundant in our model, but they encode the
- // number of outputs so we have to read them in.
- m_conv_biases.emplace_back(weights);
- } else if (linecount % 4 == 2) {
+ } else if (linecount % 6 == 1) {
+ m_batchnorm_gammas.emplace_back(weights);
+ } else if (linecount % 6 == 2) {
+ m_batchnorm_betas.emplace_back(weights);
+ } else if (linecount % 6 == 3) {
m_batchnorm_means.emplace_back(weights);
- } else if (linecount % 4 == 3) {
+ } else if (linecount % 6 == 4) {
process_bn_var(weights);
- m_batchnorm_stddevs.emplace_back(weights);
+ m_batchnorm_stddivs.emplace_back(weights);
+ } else if (linecount % 6 == 5) {
+ m_prelu_alphas.emplace_back(weights);
}
- } else {
- switch (linecount - plain_conv_wts) {
- case 0: m_conv_pol_w = std::move(weights); break;
- case 1: m_conv_pol_b = std::move(weights); break;
- case 2: std::copy(cbegin(weights), cend(weights), begin(m_bn_pol_w1)); break;
- case 3: std::copy(cbegin(weights), cend(weights), begin(m_bn_pol_w2)); break;
- case 4: std::copy(cbegin(weights), cend(weights), begin(m_ip_pol_w)); break;
- case 5: std::copy(cbegin(weights), cend(weights), begin(m_ip_pol_b)); break;
- case 6: m_conv_val_w = std::move(weights); break;
- case 7: m_conv_val_b = std::move(weights); break;
- case 8: std::copy(cbegin(weights), cend(weights), begin(m_bn_val_w1)); break;
- case 9: std::copy(cbegin(weights), cend(weights), begin(m_bn_val_w2)); break;
- case 10: std::copy(cbegin(weights), cend(weights), begin(m_ip1_val_w)); break;
- case 11: std::copy(cbegin(weights), cend(weights), begin(m_ip1_val_b)); break;
- case 12: std::copy(cbegin(weights), cend(weights), begin(m_ip2_val_w)); break;
- case 13: std::copy(cbegin(weights), cend(weights), begin(m_ip2_val_b)); break;
+ } else if (linecount < plain_conv_wts) {
+ switch ((linecount - input_weights) % 16) {
+ case 0:
+ assert(weights.size() == size_t{channels * channels * 9});
+ m_conv_weights.emplace_back(weights);
+ break;
+ case 1:
+ assert(weights.size() == size_t{channels});
+ m_batchnorm_gammas.emplace_back(weights);
+ break;
+ case 2:
+ assert(weights.size() == size_t{channels});
+ m_batchnorm_betas.emplace_back(weights);
+ break;
+ case 3:
+ assert(weights.size() == size_t{channels});
+ m_batchnorm_means.emplace_back(weights);
+ break;
+ case 4:
+ assert(weights.size() == size_t{channels});
+ process_bn_var(weights);
+ m_batchnorm_stddivs.emplace_back(weights);
+ break;
+ case 5:
+ assert(weights.size() == size_t{channels});
+ m_prelu_alphas.emplace_back(weights);
+ break;
+ case 6:
+ assert(weights.size() == size_t{channels * channels * 9});
+ m_conv_weights.emplace_back(weights);
+ break;
+ case 7:
+ assert(weights.size() == size_t{channels});
+ m_batchnorm_gammas.emplace_back(weights);
+ break;
+ case 8:
+ assert(weights.size() == size_t{channels});
+ m_batchnorm_betas.emplace_back(weights);
+ break;
+ case 9:
+ assert(weights.size() == size_t{channels});
+ m_batchnorm_means.emplace_back(weights);
+ break;
+ case 10:
+ assert(weights.size() == size_t{channels});
+ process_bn_var(weights);
+ m_batchnorm_stddivs.emplace_back(weights);
+ break;
+ case 11:
+ m_se_fc1_w.emplace_back(weights);
+ break;
+ case 12:
+ m_se_fc1_b.emplace_back(weights);
+ break;
+ case 13:
+ m_se_fc2_w.emplace_back(weights);
+ break;
+ case 14:
+ assert(weights.size() == size_t{channels});
+ m_se_fc2_b.emplace_back(weights);
+ break;
+ case 15:
+ assert(weights.size() == size_t{channels});
+ m_prelu_alphas.emplace_back(weights);
+ break;
}
+ } else if (linecount == plain_conv_wts) {
+ m_conv_pol_w = std::move(weights);
+ } else if (linecount == plain_conv_wts + 1) {
+ m_conv_pol_b = std::move(weights);
+ } else if (linecount == plain_conv_wts + 2) {
+ std::copy(cbegin(weights), cend(weights), begin(m_bn_pol_w1));
+ } else if (linecount == plain_conv_wts + 3) {
+ process_bn_var(weights);
+ std::copy(cbegin(weights), cend(weights), begin(m_bn_pol_w2));
+ } else if (linecount == plain_conv_wts + 4) {
+ std::copy(cbegin(weights), cend(weights), begin(m_prelu_pol_alpha));
+ } else if (linecount == plain_conv_wts + 5) {
+ std::copy(cbegin(weights), cend(weights), begin(m_ip_pol_w));
+ } else if (linecount == plain_conv_wts + 6) {
+ std::copy(cbegin(weights), cend(weights), begin(m_ip_pol_b));
+ } else if (linecount == plain_conv_wts + 7) {
+ m_conv_val_w = std::move(weights);
+ } else if (linecount == plain_conv_wts + 8) {
+ m_conv_val_b = std::move(weights);
+ } else if (linecount == plain_conv_wts + 9) {
+ std::copy(cbegin(weights), cend(weights), begin(m_bn_val_w1));
+ } else if (linecount == plain_conv_wts + 10) {
+ process_bn_var(weights);
+ std::copy(cbegin(weights), cend(weights), begin(m_bn_val_w2));
+ } else if (linecount == plain_conv_wts + 11) {
+ std::copy(cbegin(weights), cend(weights), begin(m_prelu_val_alpha));
+ } else if (linecount == plain_conv_wts + 12) {
+ std::copy(cbegin(weights), cend(weights), begin(m_ip1_val_w));
+ } else if (linecount == plain_conv_wts + 13) {
+ std::copy(cbegin(weights), cend(weights), begin(m_ip1_val_b));
+ } else if (linecount == plain_conv_wts + 14) {
+ std::copy(cbegin(weights), cend(weights), begin(m_ip2_val_w));
+ } else if (linecount == plain_conv_wts + 15) {
+ std::copy(cbegin(weights), cend(weights), begin(m_ip2_val_b));
}
linecount++;
}
- process_bn_var(m_bn_pol_w2);
- process_bn_var(m_bn_val_w2);
return {channels, static_cast<int>(residual_blocks)};
}
@@ -299,8 +383,11 @@ std::pair<int, int> Network::load_network_file(const std::string& filename) {
auto iss = std::stringstream{line};
// First line is the file format version id
iss >> format_version;
- if (iss.fail() || (format_version != 1 && format_version != 2)) {
+ if (iss.fail() || (format_version != 502)) {
myprintf("Weights file is the wrong version.\n");
+ if (format_version == 1 || format_version == 2) {
+ myprintf("Old weights are not supported at the moment.");
+ }
return {0, 0};
} else {
// Version 2 networks are identical to v1, except
@@ -351,24 +438,21 @@ void Network::initialize(int playouts, const std::string & weightsfile) {
weight_index++;
}
- // Biases are not calculated and are typically zero but some networks might
- // still have non-zero biases.
- // Move biases to batchnorm means to make the output match without having
- // to separately add the biases.
- for (auto i = size_t{0}; i < m_conv_biases.size(); i++) {
- for (auto j = size_t{0}; j < m_batchnorm_means[i].size(); j++) {
- m_batchnorm_means[i][j] -= m_conv_biases[i][j];
- m_conv_biases[i][j] = 0.0f;
+ // Move betas to batchnorm means
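+ // BN computes gamma * (x - mean) * s + beta, where s = 1/sqrt(var + eps)
+ // (see process_bn_var). Folding gamma into the scale, s' = gamma * s, and
+ // beta into the mean, mean' = mean - beta / s', gives the same output as
+ // s' * (x - mean').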
+ for (auto i = size_t{0}; i < m_batchnorm_betas.size(); i++) {
+ for (auto j = size_t{0}; j < m_batchnorm_betas[i].size(); j++) {
+ m_batchnorm_stddivs[i][j] *= m_batchnorm_gammas[i][j];
+ m_batchnorm_means[i][j] -= m_batchnorm_betas[i][j] / m_batchnorm_stddivs[i][j];
}
}
for (auto i = size_t{0}; i < m_bn_val_w1.size(); i++) {
- m_bn_val_w1[i] -= m_conv_val_b[i];
+ m_bn_val_w1[i] -= m_conv_val_b[i] / m_bn_val_w2[i];
m_conv_val_b[i] = 0.0f;
}
for (auto i = size_t{0}; i < m_bn_pol_w1.size(); i++) {
- m_bn_pol_w1[i] -= m_conv_pol_b[i];
+ m_bn_pol_w1[i] -= m_conv_pol_b[i] / m_bn_pol_w2[i];
m_conv_pol_b[i] = 0.0f;
}
@@ -434,18 +518,28 @@ void Network::initialize(int playouts, const std::string & weightsfile) {
// Winograd filter transformation changes filter size to 4x4
p->push_input_convolution(WINOGRAD_ALPHA, INPUT_CHANNELS,
channels, m_conv_weights[weight_index],
- m_batchnorm_means[weight_index], m_batchnorm_stddevs[weight_index]);
+ m_batchnorm_means[weight_index], m_batchnorm_stddivs[weight_index],
+ m_prelu_alphas[weight_index]);
weight_index++;
// residual blocks
for (auto i = size_t{0}; i < residual_blocks; i++) {
+ auto fc_outputs = m_se_fc1_w[i].size() / channels;
+
p->push_residual(WINOGRAD_ALPHA, channels, channels,
- m_conv_weights[weight_index],
- m_batchnorm_means[weight_index],
- m_batchnorm_stddevs[weight_index],
- m_conv_weights[weight_index + 1],
- m_batchnorm_means[weight_index + 1],
- m_batchnorm_stddevs[weight_index + 1]);
+ fc_outputs,
+ m_conv_weights[weight_index],
+ m_batchnorm_means[weight_index],
+ m_batchnorm_stddivs[weight_index],
+ m_prelu_alphas[weight_index],
+ m_conv_weights[weight_index + 1],
+ m_batchnorm_means[weight_index + 1],
+ m_batchnorm_stddivs[weight_index + 1],
+ m_prelu_alphas[weight_index + 1],
+ m_se_fc1_w[i],
+ m_se_fc1_b[i],
+ m_se_fc2_w[i],
+ m_se_fc2_b[i]);
weight_index += 2;
}
@@ -521,24 +615,36 @@ void batchnorm(const size_t channels,
std::vector<float>& data,
const float* const means,
const float* const stddivs,
- const float* const eltwise = nullptr) {
- const auto lambda_ReLU = [](const auto val) { return (val > 0.0f) ?
- val : 0.0f; };
+ const float* const prelu_alphas,
+ const bool relu = true,
+ const float* const eltwise = nullptr)
+{
+ const auto lambda_PReLU = [](const auto val, const auto alpha) { return (val > 0.0f) ?
+ val : alpha * val; };
for (auto c = size_t{0}; c < channels; ++c) {
const auto mean = means[c];
const auto scale_stddiv = stddivs[c];
const auto arr = &data[c * spatial_size];
+ const auto prelu_alpha = prelu_alphas[c];
if (eltwise == nullptr) {
// Classical BN
for (auto b = size_t{0}; b < spatial_size; b++) {
- arr[b] = lambda_ReLU(scale_stddiv * (arr[b] - mean));
+ auto val = scale_stddiv * (arr[b] - mean);
+ if (relu) {
+ val = lambda_PReLU(val, prelu_alpha);
+ }
+ arr[b] = val;
}
} else {
// BN + residual add
const auto res = &eltwise[c * spatial_size];
for (auto b = size_t{0}; b < spatial_size; b++) {
- arr[b] = lambda_ReLU((scale_stddiv * (arr[b] - mean)) + res[b]);
+ auto val = scale_stddiv * (arr[b] - mean) + res[b];
+ if (relu) {
+ val = lambda_PReLU(val, prelu_alpha);
+ }
+ arr[b] = val;
}
}
}
@@ -748,7 +854,7 @@ Network::Netresult Network::get_output_internal(
// Get the moves
batchnorm(OUTPUTS_POLICY, policy_data,
- m_bn_pol_w1.data(), m_bn_pol_w2.data());
+ m_bn_pol_w1.data(), m_bn_pol_w2.data(), m_prelu_pol_alpha.data());
const auto policy_out =
innerproduct<OUTPUTS_POLICY * BOARD_SQUARES, BOARD_SQUARES + 1>(
policy_data, m_ip_pol_w, m_ip_pol_b);
@@ -756,7 +862,7 @@ Network::Netresult Network::get_output_internal(
// Now get the value
batchnorm(OUTPUTS_VALUE, value_data,
- m_bn_val_w1.data(), m_bn_val_w2.data());
+ m_bn_val_w1.data(), m_bn_val_w2.data(), m_prelu_val_alpha.data());
const auto winrate_data =
innerproduct<BOARD_SQUARES, 256>(value_data, m_ip1_val_w, m_ip1_val_b);
const auto winrate_out =
@@ -850,6 +956,67 @@ void Network::fill_input_plane_pair(const FullBoard& board,
}
}
+void Network::legal_plane(const GameState* const state,
+ std::vector<float>::iterator legal,
+ const int symmetry) {
+ const auto to_move = state->board.get_to_move();
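+ // The plane is filled where the side to move has no legal move on an
+ // empty square (i.e. it marks illegal points).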
+ for (auto idx = 0; idx < BOARD_SQUARES; idx++) {
+ const auto sym_idx = symmetry_nn_idx_table[symmetry][idx];
+ const auto x = sym_idx % BOARD_SIZE;
+ const auto y = sym_idx / BOARD_SIZE;
+ const auto vtx = state->board.get_vertex(x, y);
+ const auto color = state->board.get_square(x, y);
+ if (color == FastBoard::EMPTY) {
+ if (!state->is_move_legal(to_move, vtx)) {
+ legal[idx] = true;
+ }
+ }
+ }
+}
+
+void Network::fill_liberty_planes(const FullBoard& board,
+ std::vector<float>::iterator planes_black,
+ std::vector<float>::iterator planes_white,
+ const int plane_count,
+ const int symmetry) {
+ for (auto idx = 0; idx < BOARD_SQUARES; idx++) {
+ const auto sym_idx = symmetry_nn_idx_table[symmetry][idx];
+ const auto x = sym_idx % BOARD_SIZE;
+ const auto y = sym_idx / BOARD_SIZE;
+ const auto vtx = board.get_vertex(x, y);
+ const auto color = board.get_square(vtx);
+ if (color != FastBoard::EMPTY) {
+ auto libs = board.get_liberties(x, y);
+ if (libs > plane_count) {
+ libs = plane_count;
+ }
+ if (color == FastBoard::BLACK) {
+ planes_black[(libs-1) * BOARD_SQUARES + idx] = true;
+ } else {
+ planes_white[(libs-1) * BOARD_SQUARES + idx] = true;
+ }
+ }
+ }
+}
+
+void Network::fill_ladder_planes(const GameState* const state,
+ std::vector<float>::iterator captures,
+ std::vector<float>::iterator escapes,
+ const int symmetry) {
+ for (auto idx = 0; idx < BOARD_SQUARES; idx++) {
+ const auto sym_idx = symmetry_nn_idx_table[symmetry][idx];
+ const auto x = sym_idx % BOARD_SIZE;
+ const auto y = sym_idx / BOARD_SIZE;
+ const auto vtx = state->board.get_vertex(x, y);
+ if (Ladder::ladder_capture(*state, vtx)) {
+ captures[idx] = true;
+ }
+ if (Ladder::ladder_escape(*state, vtx)) {
+ escapes[idx] = true;
+ }
+ }
+}
+
std::vector<float> Network::gather_features(const GameState* const state,
const int symmetry) {
assert(symmetry >= 0 && symmetry < NUM_SYMMETRIES);
@@ -864,9 +1031,25 @@ std::vector<float> Network::gather_features(const GameState* const state,
const auto white_it = blacks_move ?
begin(input_data) + INPUT_MOVES * BOARD_SQUARES :
begin(input_data);
+
+ const auto legal_it = begin(input_data) + (2 * INPUT_MOVES) * BOARD_SQUARES;
+
+ const auto liberties_our = begin(input_data) + (2 * INPUT_MOVES + 1) * BOARD_SQUARES;
+ const auto liberties_other = begin(input_data) +
+ (2 * INPUT_MOVES + 1 + LIBERTY_PLANES) * BOARD_SQUARES;
+
+ const auto liberties_black_it = blacks_move ? liberties_our : liberties_other;
+ const auto liberties_white_it = blacks_move ? liberties_other : liberties_our;
+
+ const auto captures_it = begin(input_data) +
+ (2 * INPUT_MOVES + 1 + 2 * LIBERTY_PLANES) * BOARD_SQUARES;
+
+ const auto escapes_it = begin(input_data) +
+ (2 * INPUT_MOVES + 1 + 2 * LIBERTY_PLANES + 1) * BOARD_SQUARES;
+
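+ // Plane layout: [0,16) history, 16 legality, [17,21) our liberties,
+ // [21,25) opponent liberties, 25 ladder captures, 26 ladder escapes,
+ // 27/28 black/white to move.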
const auto to_move_it = blacks_move ?
- begin(input_data) + 2 * INPUT_MOVES * BOARD_SQUARES :
- begin(input_data) + (2 * INPUT_MOVES + 1) * BOARD_SQUARES;
+ begin(input_data) + (2 * INPUT_MOVES + 1 + 2 * LIBERTY_PLANES + 2) * BOARD_SQUARES :
+ begin(input_data) + (2 * INPUT_MOVES + 1 + 2 * LIBERTY_PLANES + 3) * BOARD_SQUARES;
const auto moves = std::min(state->get_movenum() + 1, INPUT_MOVES);
// Go back in time, fill history boards
@@ -880,6 +1063,15 @@ std::vector Network::gather_features(const GameState* const state,
std::fill(to_move_it, to_move_it + BOARD_SQUARES, float(true));
+ legal_plane(state, legal_it, symmetry);
+
+ fill_liberty_planes(state->board,
+ liberties_black_it,
+ liberties_white_it,
+ LIBERTY_PLANES, symmetry);
+
+ fill_ladder_planes(state, captures_it, escapes_it, symmetry);
+
return input_data;
}
diff --git a/src/Network.h b/src/Network.h
index a62d222de..0da3e90bd 100644
--- a/src/Network.h
+++ b/src/Network.h
@@ -67,8 +67,16 @@ class Network {
const int symmetry = -1,
const bool skip_cache = false);
static constexpr auto INPUT_MOVES = 8;
- static constexpr auto INPUT_CHANNELS = 2 * INPUT_MOVES + 2;
+ static constexpr auto LIBERTY_PLANES = 4;
+
+ // History 2 * INPUT_MOVES
+ // Legal 1
+ // Liberties us/them 2 * LIBERTY_PLANES
+ // Ladder capture/escape 2
+ // Black/white to play 2
+ static constexpr auto INPUT_CHANNELS = 2 * INPUT_MOVES + 1 + 2 * LIBERTY_PLANES + 2 + 2;
static constexpr auto OUTPUTS_POLICY = 2;
static constexpr auto OUTPUTS_VALUE = 1;
@@ -115,6 +123,18 @@ class Network {
std::vector<float>::iterator black,
std::vector<float>::iterator white,
const int symmetry);
+ static void legal_plane(const GameState* const state,
+ std::vector<float>::iterator legal,
+ const int symmetry);
+ static void fill_liberty_planes(const FullBoard& board,
+ std::vector<float>::iterator planes_black,
+ std::vector<float>::iterator planes_white,
+ const int plane_count,
+ const int symmetry);
+ static void fill_ladder_planes(const GameState* const state,
+ std::vector<float>::iterator captures,
+ std::vector<float>::iterator escapes,
+ const int symmetry);
bool probe_cache(const GameState* const state, Network::Netresult& result);
std::unique_ptr m_forward;
#ifdef USE_OPENCL_SELFCHECK
@@ -132,15 +152,23 @@ class Network {
// Input + residual block tower
std::vector<std::vector<float>> m_conv_weights;
- std::vector<std::vector<float>> m_conv_biases;
+ std::vector<std::vector<float>> m_batchnorm_betas;
+ std::vector<std::vector<float>> m_batchnorm_gammas;
std::vector<std::vector<float>> m_batchnorm_means;
- std::vector<std::vector<float>> m_batchnorm_stddevs;
+ std::vector<std::vector<float>> m_batchnorm_stddivs;
+ std::vector<std::vector<float>> m_prelu_alphas;
+
+ std::vector<std::vector<float>> m_se_fc1_w;
+ std::vector<std::vector<float>> m_se_fc1_b;
+ std::vector<std::vector<float>> m_se_fc2_w;
+ std::vector<std::vector<float>> m_se_fc2_b;
// Policy head
std::vector<float> m_conv_pol_w;
std::vector<float> m_conv_pol_b;
std::array<float, OUTPUTS_POLICY> m_bn_pol_w1;
std::array<float, OUTPUTS_POLICY> m_bn_pol_w2;
+ std::array<float, OUTPUTS_POLICY> m_prelu_pol_alpha;
std::array<float, OUTPUTS_POLICY * BOARD_SQUARES * (BOARD_SQUARES + 1)> m_ip_pol_w;
std::array<float, BOARD_SQUARES + 1> m_ip_pol_b;
@@ -150,6 +178,7 @@ class Network {
std::vector<float> m_conv_val_b;
std::array<float, OUTPUTS_VALUE> m_bn_val_w1;
std::array<float, OUTPUTS_VALUE> m_bn_val_w2;
+ std::array<float, OUTPUTS_VALUE> m_prelu_val_alpha;
std::array<float, OUTPUTS_VALUE * BOARD_SQUARES * 256> m_ip1_val_w;
std::array<float, 256> m_ip1_val_b;
diff --git a/src/OpenCL.cpp b/src/OpenCL.cpp
index f22278a29..6f00f40f8 100644
--- a/src/OpenCL.cpp
+++ b/src/OpenCL.cpp
@@ -74,6 +74,14 @@ static const std::string sourceCode_convolve3 =
#include "kernels/convolve3.opencl"
;
+static std::string sourceCode_global_avg_pooling =
+ #include "kernels/pooling.opencl"
+;
+
+static std::string sourceCode_apply_se =
+ #include "kernels/apply_se.opencl"
+;
+
const std::string sourceCode_sgemm =
#include "kernels/clblast/xgemm_part1.opencl"
#include "kernels/clblast/xgemm_part2.opencl"
@@ -81,6 +89,11 @@ const std::string sourceCode_sgemm =
#include "kernels/clblast/xgemm_batched.opencl"
;
+const std::string sourceCode_sgemv =
+ #include "kernels/clblast/xgemv.opencl"
+;
+
+
template <typename net_t>
void OpenCL<net_t>::ensure_context_initialized(OpenCLContext &opencl_context) {
if (!opencl_context.m_is_initialized) {
@@ -95,10 +108,16 @@ void OpenCL::ensure_context_initialized(OpenCLContext &opencl_context) {
cl::Kernel(m_program, "XgemmBatched");
opencl_context.m_out_transform_bn_kernel =
cl::Kernel(m_program, "out_transform_fused_bn");
+ opencl_context.m_sgemv_kernel =
+ cl::Kernel(m_program, "Xgemv");
opencl_context.m_out_transform_bn_in_kernel =
cl::Kernel(m_program, "out_transform_fused_bn_in");
opencl_context.m_commandqueue =
cl::CommandQueue(m_context, m_device);
+ opencl_context.m_global_avg_pooling_kernel =
+ cl::Kernel(m_program, "global_avg_pooling");
+ opencl_context.m_apply_se_kernel =
+ cl::Kernel(m_program, "apply_se");
opencl_context.m_is_initialized = true;
}
}
@@ -158,19 +177,25 @@ void OpenCL_Network<net_t>::forward(const std::vector<float>& input,
MAX_BATCH * m_ceil * m_ceil * max_channels * sizeof(net_t);
const auto alloc_vm_size =
MAX_BATCH * WINOGRAD_TILE * m_ceil * n_ceil * sizeof(net_t);
+ const auto alloc_pool_size = MAX_BATCH * max_channels * sizeof(net_t);
auto v_zeros = std::vector<net_t>(alloc_vm_size);
opencl_context.m_inBuffer = cl::Buffer(
m_opencl.m_context,
CL_MEM_READ_WRITE, alloc_inSize);
+
opencl_context.m_inBuffer2 = cl::Buffer(
m_opencl.m_context,
CL_MEM_READ_WRITE, alloc_inSize);
+
+ // Zero pad the unused areas in V.
+ // Zeros must not be overwritten or convolution gives incorrect results.
opencl_context.m_VBuffer = cl::Buffer(
m_opencl.m_context,
CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS | CL_MEM_COPY_HOST_PTR,
alloc_vm_size, v_zeros.data(), nullptr);
+
opencl_context.m_MBuffer = cl::Buffer(
m_opencl.m_context,
CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, alloc_vm_size);
@@ -178,10 +203,16 @@ void OpenCL_Network::forward(const std::vector& input,
opencl_context.m_pinnedOutBuffer_pol = cl::Buffer(
m_opencl.m_context,
CL_MEM_WRITE_ONLY | CL_MEM_ALLOC_HOST_PTR, MAX_BATCH * finalSize_pol);
+
opencl_context.m_pinnedOutBuffer_val = cl::Buffer(
m_opencl.m_context,
CL_MEM_WRITE_ONLY | CL_MEM_ALLOC_HOST_PTR, MAX_BATCH * finalSize_val);
+ opencl_context.m_pool_buffer = cl::Buffer(
+ m_opencl.m_context,
+ CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS, alloc_pool_size);
+
+
opencl_context.m_buffers_allocated = true;
}
@@ -193,6 +224,7 @@ void OpenCL_Network<net_t>::forward(const std::vector<float>& input,
std::vector<net_t> net_t_input(input.size());
std::copy(begin(input), end(input), begin(net_t_input));
+ cl::Buffer & pool_buffer = opencl_context.m_pool_buffer;
const auto inSize = sizeof(net_t) * input.size();
queue.enqueueWriteBuffer(inBuffer, CL_FALSE, 0, inSize, net_t_input.data());
@@ -222,16 +254,20 @@ void OpenCL_Network::forward(const std::vector& input,
nullptr,
bn_weights,
skip_in_trans, skip_next_in_trans, true,
+ true,
batch_size);
skip_in_trans = skip_next_in_trans;
} else if (layer.is_residual_block) {
assert(layer.channels == layer.outputs);
assert(niter != cend(m_layers));
+
auto conv1_weights = begin(layer.weights);
auto bn1_weights = begin(layer.weights) + 1;
- auto conv2_weights = begin(layer.weights) + 3;
- auto bn2_weights = begin(layer.weights) + 4;
+ auto conv2_weights = begin(layer.weights) + 4;
+ auto bn2_weights = begin(layer.weights) + 5;
+ auto se_weights = begin(layer.weights) + 7;
+
convolve3(opencl_context,
layer.channels,
layer.outputs,
@@ -243,24 +279,35 @@ void OpenCL_Network::forward(const std::vector& input,
nullptr,
bn1_weights,
skip_in_trans, true, false,
+ true,
batch_size);
auto skip_next_in_trans = false;
- if (niter->is_residual_block) {
- skip_next_in_trans = true;
- }
+
convolve3(opencl_context,
layer.channels,
layer.outputs,
inBuffer2,
- inBuffer,
+ inBuffer2,
VBuffer,
MBuffer,
conv2_weights,
- &inBuffer,
+ nullptr,
bn2_weights,
- true, skip_next_in_trans, true,
+ true, skip_next_in_trans, false,
+ false,
batch_size);
+
+ squeeze_excitation(opencl_context,
+ layer.outputs,
+ layer.se_fc_outputs,
+ inBuffer2,
+ pool_buffer,
+ MBuffer,
+ se_weights,
+ inBuffer,
+ batch_size);
+
skip_in_trans = skip_next_in_trans;
} else {
assert(layer.is_convolve1);
@@ -308,6 +355,123 @@ void OpenCL_Network<net_t>::forward(const std::vector<float>& input,
}
+template <typename net_t>
+void OpenCL_Network<net_t>::squeeze_excitation(
+ OpenCLContext & opencl_context,
+ int channels,
+ int fc_outputs,
+ cl::Buffer& bufferIn,
+ cl::Buffer& bufferTemp1,
+ cl::Buffer& bufferTemp2,
+ weight_slice_t weights,
+ cl::Buffer& bufferResidual,
+ int batch_size) {
+
+ cl::Kernel & pooling_kernel = opencl_context.m_global_avg_pooling_kernel;
+ cl::Kernel & apply_se_kernel = opencl_context.m_apply_se_kernel;
+ cl::CommandQueue & queue = opencl_context.m_commandqueue;
+
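+    // SE head: global-average-pool each channel, squeeze FC (+ReLU),
+    // excite FC, then apply_se gates the input by sigmoid and adds the
+    // residual.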
+ try {
+ pooling_kernel.setArg(0, channels);
+ pooling_kernel.setArg(1, bufferIn);
+ pooling_kernel.setArg(2, bufferTemp1);
+
+ queue.enqueueNDRangeKernel(pooling_kernel, cl::NullRange,
+ cl::NDRange(BOARD_SIZE, batch_size * channels),
+ cl::NDRange(BOARD_SIZE, 1));
+ } catch (const cl::Error &e) {
+ std::cerr << "Error in squeeze_excitation: " << e.what() << ": "
+ << e.err() << std::endl;
+ throw;
+ }
+
+ innerproduct(opencl_context,
+ bufferTemp1,
+ weights[0],
+ weights[1],
+ bufferTemp2,
+ channels,
+ fc_outputs,
+ true,
+ batch_size);
+
+ innerproduct(opencl_context,
+ bufferTemp2,
+ weights[2],
+ weights[3],
+ bufferTemp1,
+ fc_outputs,
+ channels,
+ false,
+ batch_size);
+
+ try {
+ apply_se_kernel.setArg(0, channels);
+ apply_se_kernel.setArg(1, bufferIn);
+ apply_se_kernel.setArg(2, bufferResidual);
+ apply_se_kernel.setArg(3, bufferTemp1);
+ apply_se_kernel.setArg(4, weights[4]);
+
+ queue.enqueueNDRangeKernel(apply_se_kernel, cl::NullRange,
+ cl::NDRange(BOARD_SIZE, batch_size * channels));
+ } catch (const cl::Error &e) {
+ std::cerr << "Error in squeeze_excitation: " << e.what() << ": "
+ << e.err() << std::endl;
+ throw;
+ }
+
+}
+
+template <typename net_t>
+void OpenCL_Network<net_t>::innerproduct(
+ OpenCLContext & opencl_context,
+ const cl::Buffer& input,
+ const cl::Buffer& weights,
+ const cl::Buffer& biases,
+ cl::Buffer& output,
+ int inputs, int outputs,
+ bool relu,
+ int batch_size) {
+
+ // TODO: innerproduct batching
+ assert(batch_size == 1);
+
+ auto sgemv_kernel = opencl_context.m_sgemv_kernel;
+ cl::CommandQueue & queue = opencl_context.m_commandqueue;
+
+ //TODO: Tune these
+ size_t wgs1 = 32;
+ size_t wpt1 = 1;
+
+ auto m_ceil = int(ceilMultiple(outputs, wgs1 * wpt1));
+ auto global_size = m_ceil / wpt1;
+ auto local_size = wgs1;
+
+ try {
+ // Sets the kernel arguments
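+        // Xgemv-style layout: m=outputs, n=inputs, A=weights (ld=inputs),
+        // x=input, y=output, with fused bias add and optional ReLU.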
+ sgemv_kernel.setArg(0, static_cast(outputs));
+ sgemv_kernel.setArg(1, static_cast(inputs));
+ sgemv_kernel.setArg(2, weights);
+ sgemv_kernel.setArg(3, static_cast(0));
+ sgemv_kernel.setArg(4, static_cast(inputs));
+ sgemv_kernel.setArg(5, input);
+ sgemv_kernel.setArg(6, static_cast(0));
+ sgemv_kernel.setArg(7, output);
+ sgemv_kernel.setArg(8, static_cast(0));
+ sgemv_kernel.setArg(9, biases);
+ sgemv_kernel.setArg(10, static_cast(relu));
+
+ queue.enqueueNDRangeKernel(sgemv_kernel, cl::NullRange,
+ cl::NDRange(global_size),
+ cl::NDRange(local_size));
+ } catch (const cl::Error &e) {
+ std::cerr << "Error in innerproduct: " << e.what() << ": "
+ << e.err() << std::endl;
+ throw;
+ }
+}
+
+
template <typename net_t>
void OpenCL_Network<net_t>::convolve3(OpenCLContext & opencl_context,
int channels, int outputs,
@@ -321,6 +485,7 @@ void OpenCL_Network<net_t>::convolve3(OpenCLContext & opencl_context,
bool skip_in_transform,
bool fuse_in_transform,
bool store_inout,
+ bool relu,
int batch_size) {
cl::Kernel & in_transform_kernel = opencl_context.m_in_transform_kernel;
@@ -370,8 +535,12 @@ void OpenCL_Network::convolve3(OpenCLContext & opencl_context,
in_transform_kernel.setArg(4, n_ceil);
in_transform_kernel.setArg(5, batch_size);
+ // relu == false is not implemented for the fused input-transform path
+ assert(relu);
+
queue.enqueueNDRangeKernel(in_transform_kernel, cl::NullRange,
- cl::NDRange(wgs, channels));
+ cl::NDRange(wgs, channels),
+ cl::NDRange(wgs, 1));
} catch (const cl::Error &e) {
std::cerr << "Error in convolve3: " << e.what() << ": "
<< e.err() << std::endl;
@@ -425,7 +594,8 @@ void OpenCL_Network::convolve3(OpenCLContext & opencl_context,
}
out_transform_bn_in_kernel.setArg(8, bn_weights[0]);
out_transform_bn_in_kernel.setArg(9, bn_weights[1]);
- out_transform_bn_in_kernel.setArg(10,
+ out_transform_bn_in_kernel.setArg(10, bn_weights[2]);
+ out_transform_bn_in_kernel.setArg(11,
cl::Local(dim_size * width * height * sizeof(float)));
queue.enqueueNDRangeKernel(out_transform_bn_in_kernel,
@@ -446,6 +616,8 @@ void OpenCL_Network::convolve3(OpenCLContext & opencl_context,
}
out_transform_bn_kernel.setArg(7, bn_weights[0]);
out_transform_bn_kernel.setArg(8, bn_weights[1]);
+ out_transform_bn_kernel.setArg(9, bn_weights[2]);
+ out_transform_bn_kernel.setArg(10, static_cast(relu));
queue.enqueueNDRangeKernel(out_transform_bn_kernel, cl::NullRange,
cl::NDRange(outputs, wgs));
@@ -761,7 +933,10 @@ void OpenCL::initialize(const int channels, int gpu, bool silent) {
+ sourceCode_config
+ sourceCode_convolve1
+ sourceCode_convolve3
- + sourceCode_sgemm);
+ + sourceCode_sgemm
+ + sourceCode_global_avg_pooling
+ + sourceCode_sgemv
+ + sourceCode_apply_se);
} catch (const cl::Error &e) {
myprintf("Error getting kernels: %s: %d", e.what(), e.err());
throw std::runtime_error("Error getting OpenCL kernels.");
diff --git a/src/OpenCL.h b/src/OpenCL.h
index d18e190da..7126f72a1 100644
--- a/src/OpenCL.h
+++ b/src/OpenCL.h
@@ -42,6 +42,7 @@ class Layer {
private:
unsigned int channels{0};
unsigned int outputs{0};
+ unsigned int se_fc_outputs{0};
unsigned int filter_size{0};
bool is_input_convolution{false};
bool is_residual_block{false};
@@ -59,12 +60,16 @@ class OpenCLContext {
cl::Kernel m_merge_kernel;
cl::Kernel m_in_transform_kernel;
cl::Kernel m_sgemm_kernel;
+ cl::Kernel m_sgemv_kernel;
cl::Kernel m_out_transform_bn_kernel;
cl::Kernel m_out_transform_bn_in_kernel;
+ cl::Kernel m_global_avg_pooling_kernel;
+ cl::Kernel m_apply_se_kernel;
cl::Buffer m_inBuffer;
cl::Buffer m_inBuffer2;
cl::Buffer m_VBuffer;
cl::Buffer m_MBuffer;
+ cl::Buffer m_pool_buffer;
cl::Buffer m_pinnedOutBuffer_pol;
cl::Buffer m_pinnedOutBuffer_val;
bool m_buffers_allocated{false};
@@ -83,11 +88,13 @@ class OpenCL_Network {
unsigned int outputs,
const std::vector<float>& weights,
const std::vector<float>& means,
- const std::vector<float>& variances) {
+ const std::vector<float>& variances,
+ const std::vector<float>& prelu_alphas) {
size_t layer = get_layer_count();
push_weights(layer, weights);
push_weights(layer, means);
push_weights(layer, variances);
+ push_weights(layer, prelu_alphas);
m_layers[layer].is_input_convolution = true;
m_layers[layer].outputs = outputs;
m_layers[layer].filter_size = filter_size;
@@ -97,21 +104,35 @@ class OpenCL_Network {
void push_residual(unsigned int filter_size,
unsigned int channels,
unsigned int outputs,
+ unsigned int se_fc_outputs,
const std::vector<float>& weights_1,
const std::vector<float>& means_1,
const std::vector<float>& variances_1,
+ const std::vector<float>& prelu_alphas_1,
const std::vector<float>& weights_2,
const std::vector<float>& means_2,
- const std::vector<float>& variances_2) {
+ const std::vector<float>& variances_2,
+ const std::vector<float>& prelu_alphas_2,
+ const std::vector<float>& se_fc1_w,
+ const std::vector<float>& se_fc1_b,
+ const std::vector<float>& se_fc2_w,
+ const std::vector<float>& se_fc2_b) {
size_t layer = get_layer_count();
push_weights(layer, weights_1);
push_weights(layer, means_1);
push_weights(layer, variances_1);
+ push_weights(layer, prelu_alphas_1);
push_weights(layer, weights_2);
push_weights(layer, means_2);
push_weights(layer, variances_2);
+ push_weights(layer, se_fc1_w);
+ push_weights(layer, se_fc1_b);
+ push_weights(layer, se_fc2_w);
+ push_weights(layer, se_fc2_b);
+ push_weights(layer, prelu_alphas_2);
m_layers[layer].is_residual_block = true;
m_layers[layer].outputs = outputs;
+ m_layers[layer].se_fc_outputs = se_fc_outputs;
m_layers[layer].filter_size = filter_size;
m_layers[layer].channels = channels;
}
@@ -158,6 +179,26 @@ class OpenCL_Network {
weight_slice_t bn_weights,
bool skip_in_transform,
bool fuse_in_transform, bool store_inout,
+ bool relu,
+ int batch_size);
+
+ void squeeze_excitation(OpenCLContext & opencl_context,
+ int channels,
+ int fc_outputs,
+ cl::Buffer& bufferIn,
+ cl::Buffer& bufferTemp1,
+ cl::Buffer& bufferTemp2,
+ weight_slice_t weights,
+ cl::Buffer& bufferResidual,
+ int batch_size);
+
+ void innerproduct(OpenCLContext & opencl_context,
+ const cl::Buffer& input,
+ const cl::Buffer& weights,
+ const cl::Buffer& biases,
+ cl::Buffer& output,
+ int inputs, int outputs,
+ bool relu,
int batch_size);
void convolve1(OpenCLContext & opencl_context,
diff --git a/src/OpenCLScheduler.cpp b/src/OpenCLScheduler.cpp
index 11ead3f35..7dfff9a94 100644
--- a/src/OpenCLScheduler.cpp
+++ b/src/OpenCLScheduler.cpp
@@ -96,7 +96,8 @@ void OpenCLScheduler<net_t>::push_input_convolution(unsigned int filter_size,
unsigned int outputs,
const std::vector& weights,
const std::vector& means,
- const std::vector& variances) {
+ const std::vector& variances,
+ const std::vector& prelu_alphas) {
for (const auto& opencl_net : m_networks) {
const auto tuners = opencl_net->getOpenCL().get_sgemm_tuners();
@@ -112,7 +113,7 @@ void OpenCLScheduler::push_input_convolution(unsigned int filter_size,
m_ceil, k_ceil);
opencl_net->push_input_convolution(
filter_size, channels, outputs,
- Upad, means, variances
+ Upad, means, variances, prelu_alphas
);
}
}
@@ -121,12 +122,19 @@ template <typename net_t>
void OpenCLScheduler<net_t>::push_residual(unsigned int filter_size,
unsigned int channels,
unsigned int outputs,
+ unsigned int se_fc_outputs,
const std::vector<float>& weights_1,
const std::vector<float>& means_1,
const std::vector<float>& variances_1,
+ const std::vector<float>& prelu_alphas_1,
const std::vector<float>& weights_2,
const std::vector<float>& means_2,
- const std::vector<float>& variances_2) {
+ const std::vector<float>& variances_2,
+ const std::vector<float>& prelu_alphas_2,
+ const std::vector<float>& se_fc1_w,
+ const std::vector<float>& se_fc1_b,
+ const std::vector<float>& se_fc2_w,
+ const std::vector<float>& se_fc2_b) {
for (const auto& opencl_net : m_networks) {
const auto tuners = opencl_net->getOpenCL().get_sgemm_tuners();
@@ -141,12 +149,19 @@ void OpenCLScheduler::push_residual(unsigned int filter_size,
outputs, outputs,
m_ceil, m_ceil);
opencl_net->push_residual(filter_size, channels, outputs,
+ se_fc_outputs,
Upad1,
means_1,
variances_1,
+ prelu_alphas_1,
Upad2,
means_2,
- variances_2);
+ variances_2,
+ prelu_alphas_2,
+ se_fc1_w,
+ se_fc1_b,
+ se_fc2_w,
+ se_fc2_b);
}
}
diff --git a/src/OpenCLScheduler.h b/src/OpenCLScheduler.h
index 702b23be9..4d1cd8526 100644
--- a/src/OpenCLScheduler.h
+++ b/src/OpenCLScheduler.h
@@ -43,22 +43,31 @@ class OpenCLScheduler : public ForwardPipe {
std::vector<float>& output_pol,
std::vector<float>& output_val);
+
virtual void push_input_convolution(unsigned int filter_size,
unsigned int channels,
unsigned int outputs,
const std::vector<float>& weights,
const std::vector<float>& means,
- const std::vector<float>& variances);
+ const std::vector<float>& variances,
+ const std::vector<float>& prelu_alphas);
virtual void push_residual(unsigned int filter_size,
unsigned int channels,
unsigned int outputs,
+ unsigned int se_fc_outputs,
const std::vector<float>& weights_1,
const std::vector<float>& means_1,
const std::vector<float>& variances_1,
+ const std::vector<float>& prelu_alphas_1,
const std::vector<float>& weights_2,
const std::vector<float>& means_2,
- const std::vector<float>& variances_2);
+ const std::vector<float>& variances_2,
+ const std::vector<float>& prelu_alphas_2,
+ const std::vector<float>& se_fc1_w,
+ const std::vector<float>& se_fc1_b,
+ const std::vector<float>& se_fc2_w,
+ const std::vector<float>& se_fc2_b);
virtual void push_convolve(unsigned int filter_size,
unsigned int channels,
diff --git a/src/Training.cpp b/src/Training.cpp
index f9421d9e5..e777bdced 100644
--- a/src/Training.cpp
+++ b/src/Training.cpp
@@ -236,7 +236,7 @@ void Training::dump_training(int winner_color, OutputChunker& outchunk) {
for (const auto& step : m_data) {
auto out = std::stringstream{};
-    // First output 16 times an input feature plane
-    for (auto p = size_t{0}; p < 16; p++) {
+    // First output the binary input feature planes:
+    // 16 history + 1 legality + 8 liberty + 2 ladder planes
+    for (auto p = size_t{0}; p < 16 + 1 + 8 + 2; p++) {
const auto& plane = step.planes[p];
// Write it out as a string of hex characters
for (auto bit = size_t{0}; bit + 3 < plane.size(); bit += 4) {
@@ -398,3 +398,247 @@ void Training::dump_supervised(const std::string& sgf_name,
std::cout << "Dumped " << train_pos << " training positions." << std::endl;
}
+
+static int idx_to_vertex(FullBoard &board, const int idx) {
+ if (idx == BOARD_SIZE*BOARD_SIZE) {
+ return FastBoard::PASS;
+ }
+ auto x = idx % BOARD_SIZE;
+ auto y = idx / BOARD_SIZE;
+ return board.get_vertex(x, y);
+}
+
+static int hex2int(unsigned char ch)
+{
+ if (ch >= '0' && ch <= '9')
+ return ch - '0';
+ if (ch >= 'A' && ch <= 'F')
+ return ch - 'A' + 10;
+ if (ch >= 'a' && ch <= 'f')
+ return ch - 'a' + 10;
+ return -1;
+}
+
+void Training::add_features(const std::string& training_file,
+ const std::string& out_filename) {
+ /* Add feature planes to Leela Zero training data */
+
+ auto in = gzopen(training_file.c_str(), "r");
+
+ if (!in) {
+ Utils::myprintf("Failed to open file %s\n", training_file.c_str());
+ return;
+ }
+
+ auto outchunker = OutputChunker{out_filename, true};
+
+ std::stringstream stream;
+ const unsigned int bufsize = 10000;
+ std::vector<char> buffer(bufsize);
+
+ while (true) {
+ auto bytes = gzread(in, &buffer[0], bufsize);
+ if (bytes == 0) {
+ break;
+ }
+ if (bytes < 0) {
+ Utils::myprintf("gzread error: %s\n", gzerror(in, &bytes));
+ return;
+ }
+ stream.write(buffer.data(), bytes);
+ }
+ gzclose(in);
+
+ auto training_str = std::string{};
+
+ while (true) {
+ auto line = std::string{};
+ auto planecount = 0;
+ auto turn = 0;
+ auto winner = 0;
+ std::vector<float> policy;
+ std::vector