-
Notifications
You must be signed in to change notification settings - Fork 91
/
inference.cpp
99 lines (84 loc) · 3.44 KB
/
inference.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// 2018, Patrick Wieschollek <[email protected]>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <tensorflow/core/protobuf/meta_graph.pb.h>
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/public/session_options.h>
// List of (tensor name, tensor value) pairs, passed as the feed inputs of
// tensorflow::Session::Run throughout this file.
using tensor_dict = std::vector<std::pair<std::string, tensorflow::Tensor>>;
/**
 * @brief Load a previously stored model into a session.
 * @details Reads a binary MetaGraphDef from `graph_fn`, creates its graph in
 * `sess`, and then either restores the variables from `checkpoint_fn` (via the
 * saver definition stored in the meta graph) or, when no checkpoint is given,
 * runs the operation named `init`.
 *
 * In Python run:
 *
 *    saver = tf.train.Saver(tf.global_variables())
 *    saver.save(sess, './exported/my_model')
 *    tf.train.write_graph(sess.graph, '.', './exported/graph.pb', as_text=False)
 *
 * NOTE(review): the file is parsed as a MetaGraphDef, and main() passes
 * "./exported/my_model.meta", so `graph_fn` is presumably the `.meta` file
 * written next to the checkpoint rather than the bare graph written by
 * `write_graph` -- confirm against the export script.
 *
 * Without a checkpoint this relies on a graph which has an operation called
 * `init` responsible for initializing all variables, eg.
 *
 *    sess.run(tf.global_variables_initializer())  # somewhere in the python
 *                                                 # file
 *
 * @param sess active tensorflow session; must not be null (it is dereferenced
 *             without a check)
 * @param graph_fn path to the binary meta graph file
 *                 (eg. "./exported/my_model.meta")
 * @param checkpoint_fn path to checkpoint file (eg. "./exported/my_model",
 *                      optional); an empty string selects the `init` path
 * @return status of the first step that failed, or the status of the final
 *         restore/init run
 */
tensorflow::Status LoadModel(tensorflow::Session *sess,
                             const std::string &graph_fn,
                             const std::string &checkpoint_fn = "") {
  // Read in the protobuf meta graph we exported.
  tensorflow::MetaGraphDef graph_def;
  tensorflow::Status status = tensorflow::ReadBinaryProto(
      tensorflow::Env::Default(), graph_fn, &graph_def);
  if (!status.ok()) return status;

  // Create the graph inside the session.
  status = sess->Create(graph_def.graph_def());
  if (!status.ok()) return status;

  if (checkpoint_fn.empty()) {
    // No checkpoint given: run the `init` target only (no feeds, no fetches).
    return sess->Run({}, {}, {"init"}, nullptr);
  }

  // Restore the model from the checkpoint: feed the checkpoint path into the
  // saver's filename tensor and run the saver's restore operation.
  // NOTE(review): newer TensorFlow releases may expect tensorflow::tstring as
  // the element type of DT_STRING tensors -- confirm for the version in use.
  tensorflow::Tensor checkpoint_path(tensorflow::DT_STRING,
                                     tensorflow::TensorShape());
  checkpoint_path.scalar<std::string>()() = checkpoint_fn;
  const tensor_dict feed_dict = {
      {graph_def.saver_def().filename_tensor_name(), checkpoint_path}};
  return sess->Run(feed_dict, {}, {graph_def.saver_def().restore_op_name()},
                   nullptr);
}
int main(int argc, char const *argv[]) {
const std::string graph_fn = "./exported/my_model.meta";
const std::string checkpoint_fn = "./exported/my_model";
// prepare session
tensorflow::Session *sess;
tensorflow::SessionOptions options;
TF_CHECK_OK(tensorflow::NewSession(options, &sess));
TF_CHECK_OK(LoadModel(sess, graph_fn, checkpoint_fn));
// prepare inputs
tensorflow::TensorShape data_shape({1, 2});
tensorflow::Tensor data(tensorflow::DT_FLOAT, data_shape);
// same as in python file
auto data_ = data.flat<float>().data();
data_[0] = 42;
data_[1] = 43;
tensor_dict feed_dict = {
{"input_plhdr", data},
};
std::vector<tensorflow::Tensor> outputs;
TF_CHECK_OK(
sess->Run(feed_dict, {"sequential/Output_1/Softmax:0"}, {}, &outputs));
std::cout << "input " << data.DebugString() << std::endl;
std::cout << "output " << outputs[0].DebugString() << std::endl;
return 0;
}