Skip to content

Commit

Permalink
Minor cleanup involving Network::get_output.
Browse files Browse the repository at this point in the history
Pull request leela-zero#2228.
  • Loading branch information
TFiFiE authored and gcp committed Feb 19, 2019
1 parent d6db69f commit dab65c8
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
3 changes: 1 addition & 2 deletions src/GTP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -858,8 +858,7 @@ void GTP::execute(GameState & game, const std::string& xinput) {
}
} else if (symmetry == "average" || symmetry == "avg") {
vec = s_network->get_output(
&game, Network::Ensemble::AVERAGE,
Network::NUM_SYMMETRIES, false);
&game, Network::Ensemble::AVERAGE, -1, false);
} else {
vec = s_network->get_output(
&game, Network::Ensemble::DIRECT, std::stoi(symmetry), false);
Expand Down
1 change: 1 addition & 0 deletions src/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,7 @@ Network::Netresult Network::get_output(
assert(symmetry >= 0 && symmetry < NUM_SYMMETRIES);
result = get_output_internal(state, symmetry);
} else if (ensemble == AVERAGE) {
assert(symmetry == -1);
for (auto sym = 0; sym < NUM_SYMMETRIES; ++sym) {
auto tmpresult = get_output_internal(state, sym);
result.winrate +=
Expand Down
4 changes: 2 additions & 2 deletions src/Training.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ void Training::record(Network & network, GameState& state, UCTNode& root) {
step.to_move = state.board.get_to_move();
step.planes = get_planes(&state);

auto result =
network.get_output(&state, Network::Ensemble::DIRECT, 0);
const auto result = network.get_output(
&state, Network::Ensemble::DIRECT, Network::IDENTITY_SYMMETRY);
step.net_winrate = result.winrate;

const auto& best_node = root.get_best_root_child(step.to_move);
Expand Down

0 comments on commit dab65c8

Please sign in to comment.