diff --git a/src/GTP.cpp b/src/GTP.cpp index cd018eee3..2bdd8613d 100644 --- a/src/GTP.cpp +++ b/src/GTP.cpp @@ -858,8 +858,7 @@ void GTP::execute(GameState & game, const std::string& xinput) { } } else if (symmetry == "average" || symmetry == "avg") { vec = s_network->get_output( - &game, Network::Ensemble::AVERAGE, - Network::NUM_SYMMETRIES, false); + &game, Network::Ensemble::AVERAGE, -1, false); } else { vec = s_network->get_output( &game, Network::Ensemble::DIRECT, std::stoi(symmetry), false); diff --git a/src/Network.cpp b/src/Network.cpp index fd4bea544..d4bc36f7a 100644 --- a/src/Network.cpp +++ b/src/Network.cpp @@ -743,6 +743,7 @@ Network::Netresult Network::get_output( assert(symmetry >= 0 && symmetry < NUM_SYMMETRIES); result = get_output_internal(state, symmetry); } else if (ensemble == AVERAGE) { + assert(symmetry == -1); for (auto sym = 0; sym < NUM_SYMMETRIES; ++sym) { auto tmpresult = get_output_internal(state, sym); result.winrate += diff --git a/src/Training.cpp b/src/Training.cpp index 9228ae6e4..c714b6227 100644 --- a/src/Training.cpp +++ b/src/Training.cpp @@ -169,8 +169,8 @@ void Training::record(Network & network, GameState& state, UCTNode& root) { step.to_move = state.board.get_to_move(); step.planes = get_planes(&state); - auto result = - network.get_output(&state, Network::Ensemble::DIRECT, 0); + const auto result = network.get_output( + &state, Network::Ensemble::DIRECT, Network::IDENTITY_SYMMETRY); step.net_winrate = result.winrate; const auto& best_node = root.get_best_root_child(step.to_move);