Skip to content

Commit

Permalink
VN with dragon safety
Browse files Browse the repository at this point in the history
  • Loading branch information
zakki committed Jul 9, 2017
1 parent b3a5031 commit a6f04c7
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 19 deletions.
9 changes: 6 additions & 3 deletions src/GoBoard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1981,7 +1981,7 @@ WritePlanes(
std::vector<float>& data_basic,
std::vector<float>& data_features,
std::vector<float>& data_move,
std::vector<float>& data_safety,
std::vector<float>* data_safety,
std::vector<float>* data_owner,
const game_info_t *game,
const uct_node_t *root,
Expand Down Expand Up @@ -2071,16 +2071,19 @@ WritePlanes(
}

if (safety) {
data_safety.reserve(19 * 19 * 8);
data_safety->reserve(19 * 19 * 8);
for (int s = 0; s < 8; s++) {
for (int i = 1, y = board_start; y <= board_end; y++, i++) {
// cerr << setw(2) << (pure_board_size + 1 - i) << ":|";
for (int x = board_start; x <= board_end; x++) {
int pos = TransformMove(POS(x, y), tran);
OUTPUT_FEATURE(data_safety, safety[PureBoardPos(pos)] == s + 1);
OUTPUT_FEATURE((*data_safety), safety[PureBoardPos(pos)] == s + 1);
}
}
}
} else if (data_safety) {
cerr << "no safety data" << endl;
data_safety->resize(19 * 19 * 8);
}

if (data_owner) {
Expand Down
2 changes: 1 addition & 1 deletion src/GoBoard.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ void GetNeighbor4( int neighbor4[4], const int pos );
struct uct_node_t;

void WritePlanes(std::vector<float>& data_basic, std::vector<float>& data_features,
std::vector<float>& data_move, std::vector<float>& data_safety,
std::vector<float>& data_move, std::vector<float>* data_safety,
std::vector<float>* data_owner,
const game_info_t *game, const uct_node_t *root,
const uint8_t safety[PURE_BOARD_MAX],
Expand Down
4 changes: 2 additions & 2 deletions src/Gtp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ void DumpFeature(const uct_node_t& node, int color, int move, int win)
int t = mt() % 8;
//static int t = 0; t++;
int moveT = RevTransformMove(move, t);
WritePlanes(data_basic, data_features, data_history, data_safety, &data_owner,
WritePlanes(data_basic, data_features, data_history, &data_safety, &data_owner,
game_prev, &node, stored_critical, color, t);

int x = CORRECT_X(moveT) - 1;
Expand Down Expand Up @@ -1139,7 +1139,7 @@ GTP_stat(void)
std::vector<float> data_safety;
uint8_t safety[PURE_BOARD_MAX];

WritePlanes(data_basic, data_features, data_history, data_safety, &data_owner,
WritePlanes(data_basic, data_features, data_history, &data_safety, &data_owner,
game, &store_node, safety, player_color, t);

std::vector<int> eval_node_index;
Expand Down
62 changes: 49 additions & 13 deletions src/UctSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ static std::queue<std::shared_ptr<gnugo_eval_req>> eval_gnugo_queue;
static std::set<std::pair<int, int>> gnugo_searched;
static int eval_count_policy, eval_count_value;
static double owner_nn[BOARD_MAX];
static uint8_t safety[PURE_BOARD_MAX];
static uint8_t previous_safety[PURE_BOARD_MAX];

static Microsoft::MSR::CNTK::IEvaluateModel<float>* nn_policy = nullptr;
static Microsoft::MSR::CNTK::IEvaluateModel<float>* nn_value = nullptr;
Expand Down Expand Up @@ -268,7 +268,7 @@ static int SelectMaxUcbChild(const game_info_t *game, int current, int color );
static void Statistic( game_info_t *game, int winner );

// UCT探索(1回の呼び出しにつき, 1回の探索)
static int UctSearch(game_info_t *game, int color, mt19937_64 *mt, LGR& lgrf, LGRContext& lgrctx, int current, int *winner, std::vector<int>& path);
static int UctSearch( game_info_t *game, int color, mt19937_64 *mt, LGR& lgrf, LGRContext& lgrctx, int current, int *winner, std::vector<int>& path, uint8_t *safety );

// 各ノードの統計情報の更新
static void UpdateNodeStatistic( game_info_t *game, int winner, statistic_t *node_statistic );
Expand Down Expand Up @@ -309,6 +309,13 @@ shared_ptr<gnugo_eval_req> CreateGnugoReq( const game_info_t* game )
return req;
}

// Look up the safety array stored on node `index` of uct_node.
// Returns a pointer to that node's `safety` array when its `valid_safety`
// flag is set, and nullptr otherwise (callers are expected to supply their
// own fallback in that case).
static uint8_t* GetSafety(int index)
{
    auto& node = uct_node[index];
    if (!node.valid_safety) {
        return nullptr;
    }
    return node.safety;
}

/////////////////////
// 予測読みの設定 //
/////////////////////
Expand Down Expand Up @@ -549,6 +556,8 @@ InitializeSearchSetting( void )

pondered = false;
pondering_stop = true;

fill(begin(previous_safety), end(previous_safety), 0);
}


Expand Down Expand Up @@ -982,6 +991,7 @@ InitializeCandidate( child_node_t *uct_child, int pos, bool ladder )
uct_child->move_count = 0;
uct_child->win = 0;
uct_child->eval_value = false;
uct_child->eval_gnugo = false;
uct_child->index = NOT_EXPANDED;
uct_child->rate = 0.0;
uct_child->flag = false;
Expand Down Expand Up @@ -1038,6 +1048,7 @@ ExpandRoot( game_info_t *game, int color )
uct_child[i].move_count = 0;
uct_child[i].win = 0;
uct_child[i].eval_value = false;
uct_child[i].eval_gnugo = false;
}
uct_child[i].ladder = ladder[pos];
}
Expand Down Expand Up @@ -1069,6 +1080,7 @@ ExpandRoot( game_info_t *game, int color )
uct_node[index].evaled = false;
uct_node[index].value_move_count = 0;
uct_node[index].value_win = 0;
uct_node[index].valid_safety = false;
memset(uct_node[index].statistic, 0, sizeof(statistic_t) * BOARD_MAX);
fill_n(uct_node[index].seki, BOARD_MAX, false);
//fill_n(uct_node[index].safety, pure_board_max, 0);
Expand Down Expand Up @@ -1164,6 +1176,7 @@ ExpandNode( game_info_t *game, int color, int current, const std::vector<int>& p
uct_node[index].evaled = false;
uct_node[index].value_move_count = 0;
uct_node[index].value_win = 0;
uct_node[index].valid_safety = false;
memset(uct_node[index].statistic, 0, sizeof(statistic_t) * BOARD_MAX);
fill_n(uct_node[index].seki, BOARD_MAX, false);
//fill_n(uct_node[index].safety, pure_board_max, 0);
Expand Down Expand Up @@ -1275,7 +1288,7 @@ RatingNode( game_info_t *game, int color, int index, int depth )
req->index = index;
req->trans = rand() / (RAND_MAX / 8 + 1);
//req.path.swap(path);
WritePlanes(req->data_basic, req->data_features, req->data_history, req->data_safety, nullptr,
WritePlanes(req->data_basic, req->data_features, req->data_history, nullptr, nullptr,
game, root, nullptr, color, req->trans);
#if 1
eval_policy_queue.push(req);
Expand Down Expand Up @@ -1532,7 +1545,10 @@ ParallelUctSearch( thread_arg_t *arg )
// 1回プレイアウトする
//double value_result = -1;
std::vector<int> path;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
auto safety = GetSafety(current_root);
if (!safety)
safety = previous_safety;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path, safety);
// 探索を打ち切るか確認
interruption = InterruptionCheck();
// ハッシュに余裕があるか確認
Expand Down Expand Up @@ -1579,7 +1595,10 @@ ParallelUctSearch( thread_arg_t *arg )
// 1回プレイアウトする
//double value_result = -1;
std::vector<int> path;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
auto safety = GetSafety(current_root);
if (!safety)
safety = previous_safety;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path, safety);
// 探索を打ち切るか確認
interruption = InterruptionCheck();
// ハッシュに余裕があるか確認
Expand Down Expand Up @@ -1629,7 +1648,10 @@ ParallelUctSearchPondering( thread_arg_t *arg )
// 1回プレイアウトする
//double value_result = -1;
std::vector<int> path;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
auto safety = GetSafety(current_root);
if (!safety)
safety = previous_safety;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path, safety);
// ハッシュに余裕があるか確認
enough_size = CheckRemainingHashSize();
// OwnerとCriticalityを計算する
Expand All @@ -1648,7 +1670,10 @@ ParallelUctSearchPondering( thread_arg_t *arg )
// 1回プレイアウトする
//double value_result = -1;
std::vector<int> path;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path);
auto safety = GetSafety(current_root);
if (!safety)
safety = previous_safety;
UctSearch(game, color, mt[targ->thread_id], lgr, lgr_ctx[targ->thread_id], current_root, &winner, path, safety);
// ハッシュに余裕があるか確認
enough_size = CheckRemainingHashSize();
} while (!pondering_stop && enough_size);
Expand All @@ -1665,7 +1690,7 @@ ParallelUctSearchPondering( thread_arg_t *arg )
// 1回の呼び出しにつき, 1プレイアウトする //
//////////////////////////////////////////////
static int
UctSearch(game_info_t *game, int color, mt19937_64 *mt, LGR& lgrf, LGRContext& lgrctx, int current, int *winner, std::vector<int>& path)
UctSearch(game_info_t *game, int color, mt19937_64 *mt, LGR& lgrf, LGRContext& lgrctx, int current, int *winner, std::vector<int>& path, uint8_t *safety)
{
int result = 0, next_index;
double score;
Expand Down Expand Up @@ -1701,7 +1726,7 @@ UctSearch(game_info_t *game, int color, mt19937_64 *mt, LGR& lgrf, LGRContext& l
bool end_of_game = game->moves > 2 &&
game->record[game->moves - 1].pos == PASS &&
game->record[game->moves - 2].pos == PASS;
#if 0
#if 1
if (uct_child[next_index].move_count > 1000
&& uct_child[next_index].move_count % 100 == 0
&& atomic_compare_exchange_strong(&uct_child[next_index].eval_gnugo, &expected, true)) {
Expand Down Expand Up @@ -1731,7 +1756,6 @@ UctSearch(game_info_t *game, int color, mt19937_64 *mt, LGR& lgrf, LGRContext& l
UNLOCK_NODE(current);

// Enqueue value

expected = false;
if (use_nn
&& (n >= expand_threshold * value_evaluation_threshold
Expand All @@ -1748,7 +1772,7 @@ UctSearch(game_info_t *game, int color, mt19937_64 *mt, LGR& lgrf, LGRContext& l
//req->index = index;
req->trans = rand() / (RAND_MAX / 8 + 1);
req->path.swap(path);
WritePlanes(req->data_basic, req->data_features, req->data_history, req->data_safety, nullptr,
WritePlanes(req->data_basic, req->data_features, req->data_history, &req->data_safety, nullptr,
game, root, safety, color, req->trans);
LOCK_EXPAND;
eval_value_queue.push(req);
Expand Down Expand Up @@ -1790,8 +1814,9 @@ UctSearch(game_info_t *game, int color, mt19937_64 *mt, LGR& lgrf, LGRContext& l
}
// 現在見ているノードのロックを解除
UNLOCK_NODE(current);
uint8_t *s = GetSafety(uct_child[next_index].index);
// 手番を入れ替えて1手深く読む
result = UctSearch(game, color, mt, lgrf, lgrctx, uct_child[next_index].index, winner, path);
result = UctSearch(game, color, mt, lgrf, lgrctx, uct_child[next_index].index, winner, path, s != nullptr ? s : safety);
//
// double v = uct_node[current].value;
// if (*value_result < 0 && v >= 0) {
Expand Down Expand Up @@ -1946,7 +1971,7 @@ SelectMaxUcbChild( const game_info_t *game, int current, int color )

const double p_p = (double)uct_node[current].win / uct_node[current].move_count;
const double p_v = (double)uct_node[current].value_win / (uct_node[current].value_move_count + .01);
const double scale = std::max(0.2, std::min(1.0, 1.0 - (game->moves - 200) / 50.0)) * value_scale;
const double scale = value_scale;// std::max(0.2, std::min(1.0, 1.0 - (game->moves - 200) / 50.0)) * value_scale;

int start_child = 0;
if (!early_pass && current == current_root && child_num > 1) {
Expand Down Expand Up @@ -2764,8 +2789,19 @@ SearchHint( gnugo_eval_req* req )

#if 1
{
auto& safety = uct_node[req->index].safety;
fill(begin(safety), end(safety), 0);
gnugo_analyze_dragon_status(req->moves.data(), safety);
uct_node[req->index].valid_safety = true;
if (req->index == current_root) {
copy(begin(safety), end(safety), previous_safety);
for (int y = 0; y < pure_board_size; y++) {
for (int x = 0; x < pure_board_size; x++) {
cerr << setw(2) << (int)previous_safety[y * pure_board_size + x];
}
cerr << endl;
}
}
double finish_time = GetSpendTime(begin_time);
cerr << "analyze " << finish_time << "sec" << endl;
return;
Expand Down
2 changes: 2 additions & 0 deletions src/UctSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ struct uct_node_t {
//std::atomic<double> value;
std::atomic<int> value_move_count;
std::atomic<double> value_win;
std::atomic<bool> valid_safety;
uint8_t safety[PURE_BOARD_MAX];
};

struct po_info_t {
Expand Down

0 comments on commit a6f04c7

Please sign in to comment.