Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix imprecision when winrate is extreme #16

Open
wants to merge 14 commits into
base: br2gpu
Choose a base branch
from
10 changes: 10 additions & 0 deletions autogtp/Management.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,16 @@ void Management::runTuningProcess(const QString &tuneCmdLine) {
tuneProcess.waitForStarted(-1);
while (tuneProcess.state() == QProcess::Running) {
tuneProcess.waitForReadyRead(1000);
QByteArray text = tuneProcess.readAllStandardOutput();
int version_start = text.indexOf("Leela Zero ") + 11;
if (version_start > 10) {
int version_end = text.indexOf(" ", version_start);
m_leelaversion = QString(text.mid(version_start, version_end - version_start));
}
QTextStream(stdout) << text;
QTextStream(stdout) << tuneProcess.readAllStandardError();
}
QTextStream(stdout) << "Found Leela Version : " << m_leelaversion << endl;
tuneProcess.waitForFinished(-1);
}

Expand Down Expand Up @@ -316,6 +324,8 @@ Order Management::getWorkInternal(bool tuning) {
prog_cmdline.append("0");
} else {
prog_cmdline.append(QString::number(AUTOGTP_VERSION));
if (!m_leelaversion.isEmpty())
prog_cmdline.append("/"+m_leelaversion);
}
QProcess curl;
curl.start(prog_cmdline);
Expand Down
1 change: 1 addition & 0 deletions autogtp/Management.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public slots:
int m_threadsLeft;
bool m_delNetworks;
QLockFile *m_lockFile;
QString m_leelaversion;

Order getWorkInternal(bool tuning);
Order getWork(bool tuning = false);
Expand Down
56 changes: 54 additions & 2 deletions src/GTP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ FILE* cfg_logfile_handle;
bool cfg_quiet;
std::string cfg_options_str;
bool cfg_benchmark;
int cfg_analyze_interval_centis;

void GTP::setup_default_parameters() {
cfg_gtp_mode = false;
Expand Down Expand Up @@ -107,6 +108,7 @@ void GTP::setup_default_parameters() {
cfg_logfile_handle = nullptr;
cfg_quiet = false;
cfg_benchmark = false;
cfg_analyze_interval_centis = 0;

// C++11 doesn't guarantee *anything* about how random this is,
// and in MinGW it isn't random at all. But we can mix it in, which
Expand Down Expand Up @@ -147,6 +149,8 @@ const std::string GTP::s_commands[] = {
"kgs-time_settings",
"kgs-game_over",
"heatmap",
"lz-analyze",
"lz-genmove_analyze",
""
};

Expand Down Expand Up @@ -345,12 +349,18 @@ bool GTP::execute(GameState & game, std::string xinput) {
}
}
return true;
} else if (command.find("genmove") == 0) {
} else if (command.find("genmove") == 0 || command.find("lz-genmove_analyze") == 0) {
auto analysis_output = command.find("lz-genmove_analyze") == 0;
auto interval = 0;

std::istringstream cmdstream(command);
std::string tmp;

cmdstream >> tmp; // eat genmove
cmdstream >> tmp;
if (analysis_output) {
cmdstream >> interval;
}

if (!cmdstream.fail()) {
int who;
Expand All @@ -362,24 +372,66 @@ bool GTP::execute(GameState & game, std::string xinput) {
gtp_fail_printf(id, "syntax error");
return 1;
}
if (analysis_output) {
// Start of multi-line response
cfg_analyze_interval_centis = interval;
if (id != -1) gtp_printf_raw("=%d\n", id);
else gtp_printf_raw("=\n");
}
// start thinking
{
game.set_to_move(who);
// Outputs winrate and pvs for lz-genmove_analyze
int move = search->think(who);
game.play_move(move);

std::string vertex = game.move_to_text(move);
gtp_printf(id, "%s", vertex.c_str());
if (!analysis_output) {
gtp_printf(id, "%s", vertex.c_str());
} else {
gtp_printf_raw("play %s\n", vertex.c_str());
}
}
if (cfg_allow_pondering) {
// now start pondering
if (!game.has_resigned()) {
// Outputs winrate and pvs through gtp for lz-genmove_analyze
search->ponder();
}
}
if (analysis_output) {
// Terminate multi-line response
gtp_printf_raw("\n");
}
} else {
gtp_fail_printf(id, "syntax not understood");
}
analysis_output = false;
return true;
} else if (command.find("lz-analyze") == 0) {
std::istringstream cmdstream(command);
std::string tmp;
int interval;

cmdstream >> tmp; // eat lz-analyze
cmdstream >> interval;
if (!cmdstream.fail()) {
cfg_analyze_interval_centis = interval;
} else {
gtp_fail_printf(id, "syntax not understood");
return true;
}
// Start multi-line response
if (id != -1) gtp_printf_raw("=%d\n", id);
else gtp_printf_raw("=\n");
// now start pondering
if (!game.has_resigned()) {
// Outputs winrate and pvs through gtp
search->ponder();
}
cfg_analyze_interval_centis = 0;
// Terminate multi-line response
gtp_printf_raw("\n");
return true;
} else if (command.find("kgs-genmove_cleanup") == 0) {
std::istringstream cmdstream(command);
Expand Down
1 change: 1 addition & 0 deletions src/GTP.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ extern FILE* cfg_logfile_handle;
extern bool cfg_quiet;
extern std::string cfg_options_str;
extern bool cfg_benchmark;
extern int cfg_analyze_interval_centis;

/*
A list of all valid GTP2 commands is defined here:
Expand Down
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ debug:

clang:
@echo "Detected OS: ${THE_OS}"
$(MAKE) CC=clang-5.0 CXX=clang++-5.0 \
$(MAKE) CC=clang CXX=clang++ \
CXXFLAGS='$(CXXFLAGS) -Wall -Wextra -Wno-missing-braces -O3 -ffast-math -flto -march=native -std=c++14 -DNDEBUG' \
LDFLAGS='$(LDFLAGS) -flto -fuse-linker-plugin' \
leelaz
Expand Down
16 changes: 11 additions & 5 deletions src/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,7 @@ Network::Netresult Network::get_scored_moves(
for (auto sym = 0; sym < 8; ++sym) {
auto tmpresult = get_scored_moves_internal(state, sym);
result.winrate += tmpresult.winrate / 8.0f;
result.opp_winrate += tmpresult.opp_winrate / 8.0f;
result.policy_pass += tmpresult.policy_pass / 8.0f;

for (auto idx = size_t{0}; idx < BOARD_SQUARES; idx++) {
Expand All @@ -905,12 +906,15 @@ Network::Netresult Network::get_scored_moves(
result = get_scored_moves_internal(state, rand_sym);
}

// v2 format (ELF Open Go) returns black value, not stm
if (value_head_not_stm) {
// v2 format (ELF Open Go) returns black value, not side-to-move
if (!value_head_not_stm) {
if (state->board.get_to_move() == FastBoard::WHITE) {
result.winrate = 1.0f - result.winrate;
auto temp_winrate = result.winrate;
result.winrate = result.opp_winrate;
result.opp_winrate = temp_winrate;
}
}
// now winrate is black value and opp_winrate is white value

// Insert result into cache.
NNCache::get_NNCache().insert(state->board.get_hash(), result);
Expand Down Expand Up @@ -970,8 +974,9 @@ Network::Netresult Network::get_scored_moves_internal(
const auto winrate_out =
innerproduct<256, 1, false>(winrate_data, ip2_val_w, ip2_val_b);

// Sigmoid
const auto winrate_sig = (1.0f + std::tanh(winrate_out[0])) / 2.0f;
// Sigmoid: tanh normalized to take value in (0,1)
const auto winrate_sig = 1.0f / (1.0f + std::exp(-2.0f * winrate_out[0]));
const auto opp_winrate_sig = 1.0f / (1.0f + std::exp(2.0f * winrate_out[0]));

Netresult result;

Expand All @@ -982,6 +987,7 @@ Network::Netresult Network::get_scored_moves_internal(

result.policy_pass = outputs[BOARD_SQUARES];
result.winrate = winrate_sig;
result.opp_winrate = opp_winrate_sig;

return result;
}
Expand Down
5 changes: 3 additions & 2 deletions src/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ class Network {
// pass
float policy_pass;

// winrate
// winrates
float winrate;
float opp_winrate;

Netresult() : policy(BOARD_SQUARES), policy_pass(0.0f), winrate(0.0f) {}
Netresult() : policy(BOARD_SQUARES), policy_pass(0.0f), winrate(0.0f), opp_winrate(0.0f) {}
};

static Netresult get_scored_moves(const GameState* const state,
Expand Down
12 changes: 12 additions & 0 deletions src/UCTNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,18 @@ int UCTNode::get_visits() const {
return m_visits;
}

// Return the true score, without taking into account virtual losses.
float UCTNode::get_pure_eval(int tomove) const {
auto visits = get_visits();
assert(visits > 0);
auto blackeval = get_blackevals();
auto score = static_cast<float>(blackeval / double(visits));
if (tomove == FastBoard::WHITE) {
score = 1.0f - score;
}
return score;
}

float UCTNode::get_eval(int tomove) const {
// Due to the use of atomic updates and virtual losses, it is
// possible for the visit count to change underneath us. Make sure
Expand Down
6 changes: 5 additions & 1 deletion src/UCTNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class UCTNode {
float get_score() const;
void set_score(float score);
float get_eval(int tomove) const;
float get_pure_eval(int tomove) const;
float get_net_eval(int tomove) const;
void virtual_loss(void);
void virtual_loss_undo(void);
Expand All @@ -92,6 +93,7 @@ class UCTNode {
std::vector<Network::ScoreVertexPair>& nodelist,
float min_psa_ratio);
double get_blackevals() const;
double get_whiteevals() const;
void accumulate_eval(float eval);
void kill_superkos(const KoState& state);
void dirichlet_noise(float epsilon, float alpha);
Expand All @@ -108,8 +110,10 @@ class UCTNode {
// UCT eval
float m_score;
// Original net eval for this node (not children).
float m_net_eval{0.0f};
float m_net_blackeval{0.0f};
float m_net_whiteeval{0.0f};
std::atomic<double> m_blackevals{0.0};
std::atomic<double> m_whiteevals{0.0};
std::atomic<Status> m_status{ACTIVE};
// Is someone adding scores to this node?
bool m_is_expanding{false};
Expand Down
50 changes: 46 additions & 4 deletions src/UCTSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,38 @@ void UCTSearch::dump_stats(FastState & state, UCTNode & parent) {
myprintf("%4s -> %7d (V: %5.2f%%) (N: %5.2f%%) PV: %s\n",
move.c_str(),
node->get_visits(),
node->get_visits() ? node->get_eval(color)*100.0f : 0.0f,
node->get_visits() ? node->get_pure_eval(color)*100.0f : 0.0f,
node->get_score() * 100.0f,
pv.c_str());
}
tree_stats(parent);
}

void UCTSearch::output_analysis(FastState & state, UCTNode & parent) {
if (!parent.has_children()) {
return;
}

const int color = state.get_to_move();

for (const auto& node : parent.get_children()) {
// Only send variations with visits
if (!node->get_visits()) continue;

std::string move = state.move_to_text(node->get_move());
FastState tmpstate = state;
tmpstate.play_move(node->get_move());
std::string pv = move + " " + get_pv(tmpstate, *node);
auto move_eval = node->get_visits() ?
static_cast<int>(node->get_pure_eval(color) * 10000) : 0;
gtp_printf_raw("info %s %s %s %d %s %d %s %s\n",
"move", move.c_str(),
"visits", node->get_visits(),
"winrate", move_eval,
"pv", pv.c_str());
}
}

void tree_stats_helper(const UCTNode& node, size_t depth,
size_t& nodes, size_t& non_leaf_nodes,
size_t& depth_sum, size_t& max_depth,
Expand Down Expand Up @@ -465,7 +490,7 @@ void UCTSearch::dump_analysis(int playouts) {
int color = tempstate.board.get_to_move();

std::string pvstring = get_pv(tempstate, *m_root);
float winrate = 100.0f * m_root->get_eval(color);
float winrate = 100.0f * m_root->get_pure_eval(color);
myprintf("Playouts: %d, Win: %5.2f%%, PV: %s\n",
playouts, winrate, pvstring.c_str());
}
Expand Down Expand Up @@ -600,8 +625,9 @@ int UCTSearch::think(int color, passflag_t passflag) {
tg.add_task(UCTWorker(m_rootstate, this, m_root.get()));
}

bool keeprunning = true;
int last_update = 0;
auto keeprunning = true;
auto last_update = 0;
auto last_output = 0;
do {
auto currstate = std::make_unique<GameState>(m_rootstate);

Expand All @@ -613,6 +639,12 @@ int UCTSearch::think(int color, passflag_t passflag) {
Time elapsed;
int elapsed_centis = Time::timediff_centis(start, elapsed);

if (cfg_analyze_interval_centis &&
elapsed_centis - last_output > cfg_analyze_interval_centis) {
last_output = elapsed_centis;
output_analysis(m_rootstate, *m_root);
}

// output some stats every few seconds
// check if we should still search
if (elapsed_centis - last_update > 250) {
Expand Down Expand Up @@ -670,13 +702,23 @@ void UCTSearch::ponder() {
for (int i = 1; i < cfg_num_threads; i++) {
tg.add_task(UCTWorker(m_rootstate, this, m_root.get()));
}
Time start;
auto keeprunning = true;
auto last_output = 0;
do {
auto currstate = std::make_unique<GameState>(m_rootstate);
auto result = play_simulation(*currstate, m_root.get());
if (result.valid()) {
increment_playouts();
}
if (cfg_analyze_interval_centis) {
Time elapsed;
int elapsed_centis = Time::timediff_centis(start, elapsed);
if (elapsed_centis - last_output > cfg_analyze_interval_centis) {
last_output = elapsed_centis;
output_analysis(m_rootstate, *m_root);
}
}
keeprunning = is_running();
keeprunning &= !stop_thinking(0, 1);
} while (!Utils::input_pending() && keeprunning);
Expand Down
1 change: 1 addition & 0 deletions src/UCTSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class UCTSearch {
int get_best_move(passflag_t passflag);
void update_root();
bool advance_to_new_rootstate();
void output_analysis(FastState & state, UCTNode & parent);

GameState & m_rootstate;
std::unique_ptr<GameState> m_last_rootstate;
Expand Down
Loading