// test_network.cpp
#include <iostream>
#include <vector>
#include <string>
#include <torch/torch.h>
#include <torch/script.h>
#include <torch/csrc/api/include/torch/torch.h>
#include "libtorch.h"
#include "GameField.h"
#include "MCTS.h"
#include <chrono>
using namespace std;
using namespace chrono;
/// Timing harness: runs one MCTS search on a default-constructed GameField,
/// prints the board, the elapsed wall time of the search in seconds, and then
/// every element returned by MCTS::get_action_probs, one per line.
///
/// @return 0 on success; 1 if a std::exception was thrown while loading the
///         model or running the search (its message is written to stderr).
int main()
{
    try {
        // Disabled reference snippet: loads the TorchScript checkpoint directly
        // through torch::jit and prints the two output tensors of a forward pass.
        //auto network(
        //	std::make_shared<torch::jit::script::Module>(
        //		torch::jit::load("D:\\Project\\py\\torchGo\\models\\checkpoint.pt")
        //		)
        //);
        ////cout << network << endl;
        //std::vector<torch::jit::IValue> inputs = { torch::ones({ 2, 3, 8, 8 }) };
        ////cout << inputs[0] << endl;
        //auto result = network->forward(inputs).toTuple();
        ////cout << result << endl;
        //torch::Tensor p_batch = result->elements()[0]
        //	.toTensor()
        //	.exp()
        //	.toType(torch::kFloat32)
        //	.to(at::kCPU);
        //torch::Tensor v_batch =
        //	result->elements()[1].toTensor().toType(torch::kFloat32).to(at::kCPU);
        //cout << p_batch << endl;
        //cout << v_batch << endl;

        GameField g;
        // NOTE(review): arguments are presumably (model path, use GPU, batch size)
        // -- confirm against the NeuralNetwork declaration in libtorch.h.
        NeuralNetwork net("./models/checkpoint.pt", true, 64);
        // NOTE(review): the numeric arguments are search settings whose meaning is
        // not visible here -- confirm against the MCTS constructor in MCTS.h.
        MCTS mcts(&net, 12, 5, 800, 3, 65);
        g.print();

        // steady_clock is monotonic, so the measured interval cannot be skewed
        // by wall-clock adjustments the way a system_clock interval can.
        const auto started = std::chrono::steady_clock::now();
        auto p = mcts.get_action_probs(&g, 1);
        const auto finished = std::chrono::steady_clock::now();

        // duration<double> with the default period is expressed in seconds.
        const std::chrono::duration<double> elapsed = finished - started;
        std::cout << elapsed.count() << " seconds" << std::endl;

        for (const auto& i : p) {
            std::cout << i << '\n';
        }
        std::cout << std::flush;
    }
    catch (const std::exception& e) {
        // Report on stderr and signal failure through the exit status so that
        // scripts driving this program can detect a failed run.
        std::cerr << e.what() << std::endl;
        return 1;
    }
    return 0;
}