-
Notifications
You must be signed in to change notification settings - Fork 5
/
predictor.cpp
565 lines (489 loc) · 17.7 KB
/
predictor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
#ifdef __linux__
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <cuda_runtime_api.h>
#include "NvCaffeParser.h"
// #include "parserOnnxConfig.h"
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "NvOnnxConfig.h"
#include "NvOnnxParser.h"
#include "NvUffParser.h"
#include "json.hpp"
#include "predictor.hpp"
#include "timer.h"
#include "timer.impl.hpp"
#include "half.hpp"
using namespace std;
using namespace nvinfer1;
using namespace nvinfer1;
using namespace nvonnxparser;
using namespace nvcaffeparser1;
using namespace nvuffparser;
using std::string;
using json = nlohmann::json;
// Process-wide error state exposed through TenorRTPredictor_HasError and
// TenorRTPredictor_GetLastError.
// NOTE(review): these are plain globals with no synchronization, so
// concurrent C API calls can overwrite each other's error.
static bool has_error = false;
static std::string error_string{""};

// Resets the error state; run at the start of every wrapped C entry point.
static void clear_error() {
  has_error = false;
  error_string = "";
}

// Records `err` as the most recent error.
static void set_error(const std::string &err) {
  has_error = true;
  error_string = err;
}

// START_C_DEFINION() / END_C_DEFINION(res) bracket the body of each C entry
// point. START clears the previous error and opens a try block. END closes it
// and converts any escaping exception into a message on std::cerr plus a
// recorded error, then returns `res`. The error must not be cleared again
// after the handlers; otherwise the caller could never observe it.
#define START_C_DEFINION()                                                     \
  clear_error();                                                               \
  try {
#define END_C_DEFINION(res)                                                    \
  }                                                                            \
  catch (const std::exception &e) {                                            \
    std::cerr << "ERROR: " << e.what() << "\n";                                \
    set_error(e.what());                                                       \
  }                                                                            \
  catch (const std::string &e) {                                               \
    std::cerr << "ERROR: " << e << "\n";                                       \
    set_error(e);                                                              \
  }                                                                            \
  catch (...) {                                                                \
    std::cerr << "ERROR: unknown exception in go-tensorrt"                     \
              << "\n";                                                         \
    set_error("unknown exception in go-tensorrt");                             \
  }                                                                            \
  return res
// TensorRT logger shared by the builder, the runtime and the plugin library
// (see createInferBuilder / createInferRuntime / initLibNvInferPlugins).
// Prints warnings and errors to stdout; info and verbose messages are dropped.
class Logger : public ILogger {
  void log(Severity severity, const char *msg) override {
    // Lower Severity values are more severe (kINTERNAL_ERROR < kERROR <
    // kWARNING < kINFO < kVERBOSE). Keep everything up to and including
    // warnings; suppress info-level (and more verbose) messages.
    if (severity <= Severity::kWARNING) {
      std::cout << msg << std::endl;
    }
  }
} gLogger;
// CHECK(stmt): evaluates `stmt` and discards its result. It stays a plain
// pass-through because it is also applied to void calls
// (e.g. Predictor::synchronize()).
#define CHECK(stmt) stmt
// CHECK_ERROR(stmt): evaluates a CUDA runtime call and throws
// std::runtime_error when it does not return cudaSuccess. A failed cudaMalloc
// or copy therefore propagates as an exception (which the
// START_C_DEFINION/END_C_DEFINION wrappers turn into a C API error) instead
// of being silently ignored.
#define CHECK_ERROR(stmt)                                                      \
  do {                                                                         \
    const cudaError_t check_error_status_ = (stmt);                            \
    if (check_error_status_ != cudaSuccess) {                                  \
      throw std::runtime_error(std::string("CUDA error: ") +                   \
                               cudaGetErrorString(check_error_status_) +       \
                               " in " #stmt);                                  \
    }                                                                          \
  } while (0)
class Profiler : public IProfiler {
public:
Profiler(profile *prof) : prof_(prof) {
if (prof_ == nullptr) {
return;
}
prof_->start(); // reset start time
current_time_ = prof_->get_start();
}
/** \brief layer time reporting callback
*
* \param layerName the name of the layer, set when constructing the network
* definition
* \param ms the time in milliseconds to execute the layer
*/
virtual void reportLayerTime(const char *layer_name, float ms) {
if (prof_ == nullptr) {
return;
}
shapes_t shapes{};
auto duration = std::chrono::nanoseconds((timestamp_t::rep)(1000000 * ms));
auto e = new profile_entry(current_layer_sequence_index_, layer_name, "",
shapes);
e->set_start(current_time_);
e->set_end(current_time_ + duration);
prof_->add(current_layer_sequence_index_ - 1, e);
current_layer_sequence_index_++;
current_time_ += duration;
}
virtual ~Profiler() {}
private:
profile *prof_{nullptr};
int current_layer_sequence_index_{1};
timestamp_t current_time_{};
};
class Predictor {
public:
Predictor(IExecutionContext *context,
std::vector<std::string> input_layer_names,
std::vector<std::string> output_layer_names, int32_t batch_size)
: context_(context), input_layer_names_(input_layer_names),
output_layer_names_(output_layer_names), batch_size_(batch_size) {
cudaStreamCreate(&stream_);
const ICudaEngine &engine = context_->getEngine();
data_.resize(engine.getNbBindings());
};
void Run() {
if (context_ == nullptr) {
throw std::runtime_error("tensorrt prediction error null context_");
}
const ICudaEngine &engine = context_->getEngine();
if (engine.getNbBindings() !=
input_layer_names_.size() + output_layer_names_.size()) {
throw std::runtime_error(std::string("tensorrt prediction error on ") +
std::to_string(__LINE__));
}
Profiler profiler(prof_);
// Set the custom profiler.
context_->setProfiler(&profiler);
context_->execute(batch_size_, data_.data());
// context_->enqueue(batch_size_, data_.data(), stream_, nullptr);
}
template <typename T>
void AddInput(const std::string &name, const T *host_data,
size_t num_elements) {
void *gpu_data = nullptr;
const ICudaEngine &engine = context_->getEngine();
const auto idx = engine.getBindingIndex(name.c_str());
if (idx == -1) {
throw std::runtime_error(std::string("invalid input name ") + name);
}
const auto byte_count = batch_size_ * num_elements * sizeof(T);
CHECK_ERROR(cudaMalloc(&gpu_data, byte_count));
CHECK_ERROR(cudaMemcpyAsync(gpu_data, host_data, byte_count,
cudaMemcpyHostToDevice, stream_));
data_[idx] = gpu_data;
}
template <typename T> void AddOutput(const std::string &name) {
void *gpu_data = nullptr;
const ICudaEngine &engine = context_->getEngine();
const auto idx = engine.getBindingIndex(name.c_str());
if (idx == -1) {
throw std::runtime_error(std::string("invalid output name ") + name);
}
const auto dims = engine.getBindingDimensions(idx);
const auto ndims = dims.nbDims;
auto num_elements = 1;
std::vector<int> res{};
for (int ii = 0; ii < ndims; ii++) {
num_elements *= dims.d[ii];
}
const auto byte_count = batch_size_ * num_elements * sizeof(T);
CHECK_ERROR(cudaMalloc(&gpu_data, byte_count));
data_[idx] = gpu_data;
}
void *GetOutputData(const std::string &name) {
synchronize();
const ICudaEngine &engine = context_->getEngine();
const auto idx = engine.getBindingIndex(name.c_str());
if (idx == -1) {
throw std::runtime_error(std::string("invalid output name ") + name);
}
if (engine.bindingIsInput(idx)) {
throw std::runtime_error(std::string("the layer name is not an output ") +
name);
}
const auto shape = GetOutputShape(name);
auto element_byte_count = 1;
const auto data_type = engine.getBindingDataType(idx);
const size_t num_elements =
std::accumulate(begin(shape), end(shape), 1, std::multiplies<size_t>());
#ifdef DEBUG
std::cout << "shape = " << shape[0] << "\n";
#endif
switch (data_type) {
#define DISPATCH_GET_OUTPUT(DType, CType) \
case DType: \
element_byte_count = sizeof(CType); \
break; \
TensorRT_DType_Dispatch(DISPATCH_GET_OUTPUT)
#undef DISPATCH_GET_OUTPUT
case DataType::kFLOAT:
element_byte_count = sizeof(float);
break;
case DataType::kHALF:
element_byte_count = sizeof(short);
break;
case DataType::kINT8:
element_byte_count = sizeof(int8_t);
break;
case DataType::kINT32:
element_byte_count = sizeof(int32_t);
break;
default:
throw std::runtime_error("unexpected output type");
}
const auto byte_count = num_elements * element_byte_count;
void *res_data = malloc(byte_count);
#ifdef DEBUG
std::cout << "byte_count = " << byte_count << "\n";
#endif
CHECK(cudaMemcpy(res_data, data_[idx], byte_count, cudaMemcpyDeviceToHost));
return res_data;
}
std::vector<int32_t> GetOutputShape(const std::string &name) {
synchronize();
const ICudaEngine &engine = context_->getEngine();
const auto idx = engine.getBindingIndex(name.c_str());
if (idx == -1) {
throw std::runtime_error(std::string("invalid output name ") + name);
}
const auto dims = engine.getBindingDimensions(idx);
const auto ndims = dims.nbDims;
#ifdef DEBUG
std::cout << __LINE__ << " >>> "
<< "name = " << name << "\n";
std::cout << __LINE__ << " >>> "
<< "ndims = " << ndims << "\n";
#endif
std::vector<int> res{};
for (int ii = 0; ii < ndims; ii++) {
#ifdef DEBUG
std::cout << __LINE__ << " >>> " << ii << " == " << dims.d[ii] << "\n";
#endif
res.emplace_back(dims.d[ii]);
}
#ifdef DEBUG
std::cout << __LINE__ << " >>> "
<< "res.size() = " << res.size() << "\n";
#endif
return res;
}
void synchronize() { CHECK(cudaStreamSynchronize(stream_)); }
~Predictor() {
for (auto data : data_) {
cudaFree(data);
}
if (context_) {
context_->destroy();
}
if (prof_) {
prof_->reset();
delete prof_;
prof_ = nullptr;
}
}
IExecutionContext *context_{nullptr};
std::vector<string> input_layer_names_{nullptr};
std::vector<string> output_layer_names_{nullptr};
int32_t batch_size_{1};
std::vector<void *> data_{nullptr};
cudaStream_t stream_{0};
profile *prof_{nullptr};
bool profile_enabled_{false};
};
// Recovers the Predictor behind an opaque C handle.
// Throws std::runtime_error for a null handle.
Predictor *get_predictor_from_handle(PredictorHandle predictor_handle) {
  Predictor *const resolved = (Predictor *)predictor_handle;
  if (!resolved) {
    throw std::runtime_error("expecting a non-nil predictor");
  }
  return resolved;
}
namespace {
// Deleter for TensorRT interface objects, which are released through their
// destroy() method rather than operator delete.
struct TensorRTDestroyer {
  template <typename T> void operator()(T *object) const {
    if (object != nullptr) {
      object->destroy();
    }
  }
};

// Owning pointer for a TensorRT object. Releases it on every exit path,
// including when an exception propagates.
template <typename T>
using TensorRTUniquePtr = std::unique_ptr<T, TensorRTDestroyer>;
} // namespace

/**
 * Builds a TensorRT engine from the given model and wraps an execution
 * context for it in a new Predictor.
 *
 * \param model_format format of `model_files`. Only TensorRT_CaffeFormat is
 *        implemented; ONNX and UFF are reported as unsupported.
 * \param model_files for Caffe: [0] = deploy prototxt, [1] = weights file.
 * \param model_datatype precision used for the parsed weights and the
 *        builder flags (INT8 / FP16).
 * \param input_layer_names, num_input_layer_names input binding names.
 * \param output_layer_names, num_output_layer_names layers marked as network
 *        outputs.
 * \param batch_size maximum batch size of the engine, also used by Run().
 * \return a handle to the new Predictor, or nullptr on failure.
 */
PredictorHandle
NewTensorRTPredictor(TensorRT_ModelFormat model_format, char **model_files,
                     TensorRT_DType model_datatype, char **input_layer_names,
                     int32_t num_input_layer_names, char **output_layer_names,
                     int32_t num_output_layer_names, int32_t batch_size) {
  START_C_DEFINION();
  // Create the builder
  TensorRTUniquePtr<IBuilder> builder{createInferBuilder(gLogger)};
  if (builder == nullptr) {
    std::string err =
        std::string("cannot create tensorrt builder for ") + model_files[1];
    throw std::runtime_error(err);
  }
  const bool isOnnxModel =
      model_format == TensorRT_ModelFormat::TensorRT_OnnxFormat;
  // Implicit batch for non-ONNX models with a batch size, explicit otherwise.
  const auto batchFlag =
      (batch_size && !isOnnxModel)
          ? 0U
          : 1U << static_cast<uint32_t>(
                nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  TensorRTUniquePtr<INetworkDefinition> network{
      builder->createNetworkV2(batchFlag)};
  if (network == nullptr) {
    throw std::runtime_error("cannot create tensorrt network definition");
  }
  DataType blob_data_type = DataType::kFLOAT;
  switch (model_datatype) {
  case TensorRT_Byte:
    blob_data_type = DataType::kINT8;
    break;
  case TensorRT_Char:
    blob_data_type = DataType::kINT8;
    break;
  case TensorRT_Int:
    blob_data_type = DataType::kINT32;
    break;
  case TensorRT_Half:
    blob_data_type = DataType::kHALF;
    break;
  case TensorRT_Float:
    blob_data_type = DataType::kFLOAT;
    break;
  default:
    throw std::runtime_error("invalid model datatype");
  }
  std::vector<std::string> input_layer_names_vec{};
  for (int ii = 0; ii < num_input_layer_names; ii++) {
    input_layer_names_vec.emplace_back(input_layer_names[ii]);
  }
  std::vector<std::string> output_layer_names_vec{};
  for (int ii = 0; ii < num_output_layer_names; ii++) {
    output_layer_names_vec.emplace_back(output_layer_names[ii]);
  }
  // Parse the model to populate the network, then mark the outputs.
  // As in the TensorRT samples, the Caffe parser is kept alive until the
  // engine has been built (the parsed network may reference weights it owns).
  TensorRTUniquePtr<ICaffeParser> caffe_parser{};
  if (model_format == TensorRT_CaffeFormat) {
    caffe_parser.reset(nvcaffeparser1::createCaffeParser());
    if (caffe_parser == nullptr) {
      std::string err =
          std::string("cannot create tensorrt caffe parser for ") +
          model_files[1];
      throw std::runtime_error(err);
    }
    const IBlobNameToTensor *blobNameToTensor = caffe_parser->parse(
        model_files[0], model_files[1], *network, blob_data_type);
    if (blobNameToTensor == nullptr) {
      throw std::runtime_error(
          std::string("cannot parse tensorrt caffe model ") + model_files[1]);
    }
    for (const auto &output_name : output_layer_names_vec) {
      ITensor *output_tensor = blobNameToTensor->find(output_name.c_str());
      if (output_tensor == nullptr) {
        throw std::runtime_error(std::string("cannot find output layer ") +
                                 output_name);
      }
      network->markOutput(*output_tensor);
    }
  } else if (model_format == TensorRT_OnnxFormat) {
    // TODO: auto parser = nvonnxparser::createParser(*network, gLogger);
    // Without a parser the network is empty and no engine can be built.
    throw std::runtime_error("onnx model format is not supported yet");
  } else if (model_format == TensorRT_UffFormat) {
    // TODO: auto parser = nvuffparser::createUffParser();
    throw std::runtime_error("uff model format is not supported yet");
  } else {
    throw std::runtime_error("model format is not recognized");
  }
  builder->setMaxBatchSize(batch_size);
  TensorRTUniquePtr<IBuilderConfig> builder_config{
      builder->createBuilderConfig()};
  if (builder_config == nullptr) {
    throw std::runtime_error("cannot create tensorrt builder config");
  }
  builder_config->setMaxWorkspaceSize(36 << 20); // 36 MiB
  builder_config->setFlag(BuilderFlag::kGPU_FALLBACK);
  if (blob_data_type == DataType::kINT8) {
    builder_config->setFlag(BuilderFlag::kINT8);
  }
  if (blob_data_type == DataType::kHALF) {
    builder_config->setFlag(BuilderFlag::kFP16);
  }
  TensorRTUniquePtr<ICudaEngine> engine{
      builder->buildEngineWithConfig(*network, *builder_config)};
  if (engine == nullptr) {
    throw std::runtime_error("cannot build tensorrt engine");
  }
  TensorRTUniquePtr<IHostMemory> trtModelStream{engine->serialize()};
  if (trtModelStream == nullptr) {
    throw std::runtime_error("cannot serialize tensorrt engine");
  }
  // The build-time objects are no longer needed once the plan is serialized.
  network.reset();
  engine.reset();
  caffe_parser.reset();
  builder_config.reset();
  builder.reset();
  // NOTE(review): the runtime is never destroyed, as before. The engine
  // deserialized from it lives as long as the predictor, and TensorRT may
  // require the runtime to outlive its engines — TODO confirm before adding
  // cleanup here.
  IRuntime *runtime = createInferRuntime(gLogger);
  if (runtime == nullptr) {
    throw std::runtime_error("cannot create tensorrt runtime");
  }
  // Deserialize the engine
  TensorRTUniquePtr<ICudaEngine> runtime_engine{runtime->deserializeCudaEngine(
      trtModelStream->data(), trtModelStream->size(), nullptr)};
  if (runtime_engine == nullptr) {
    throw std::runtime_error("cannot deserialize tensorrt engine");
  }
  // Declared after runtime_engine so that, if anything below throws, the
  // context is destroyed before the engine it was created from.
  TensorRTUniquePtr<IExecutionContext> context{
      runtime_engine->createExecutionContext()};
  if (context == nullptr) {
    throw std::runtime_error("cannot create tensorrt execution context");
  }
  trtModelStream.reset();
  auto predictor =
      new Predictor(context.get(), std::move(input_layer_names_vec),
                    std::move(output_layer_names_vec), batch_size);
  // The predictor now owns the context. The engine must outlive that
  // context, so it is released along with it instead of being destroyed here.
  context.release();
  runtime_engine.release();
  return (PredictorHandle)predictor;
  END_C_DEFINION(nullptr);
}
// C entry point: hands `num_elements` values of type `dtype` at `host_data`
// to Predictor::AddInput for input layer `name`. The switch cases come from
// TensorRT_DType_Dispatch; any other dtype is rejected.
void TenorRTPredictor_AddInput(PredictorHandle predictor_handle,
                               const char *name, TensorRT_DType dtype,
                               void *host_data, size_t num_elements) {
  START_C_DEFINION();
  Predictor *const target = get_predictor_from_handle(predictor_handle);
  switch (dtype) {
#define ADD_INPUT_CASE(DType, CType)                                           \
  case DType:                                                                  \
    target->AddInput<CType>(name, reinterpret_cast<CType *>(host_data),        \
                            num_elements);                                     \
    break;
    TensorRT_DType_Dispatch(ADD_INPUT_CASE);
#undef ADD_INPUT_CASE
  default:
    throw std::runtime_error("unexpected input type");
  }
  END_C_DEFINION();
}
// C entry point: asks Predictor::AddOutput to allocate the device buffer for
// output layer `name`, with elements of type `dtype`. The switch cases come
// from TensorRT_DType_Dispatch; any other dtype is rejected as an unexpected
// output type.
void TenorRTPredictor_AddOutput(PredictorHandle predictor_handle,
                                const char *name, TensorRT_DType dtype) {
  START_C_DEFINION();
  auto predictor = get_predictor_from_handle(predictor_handle);
  switch (dtype) {
#define DISPATCH_ADD_OUTPUT(DType, CType)                                      \
  case DType:                                                                  \
    predictor->AddOutput<CType>(name);                                         \
    break;
    TensorRT_DType_Dispatch(DISPATCH_ADD_OUTPUT);
#undef DISPATCH_ADD_OUTPUT
  default:
    throw std::runtime_error("unexpected output type");
  }
  END_C_DEFINION();
}
// C entry point: waits for all work queued on the predictor's CUDA stream.
void TenorRTPredictor_Synchronize(PredictorHandle predictor_handle) {
  START_C_DEFINION();
  get_predictor_from_handle(predictor_handle)->synchronize();
  END_C_DEFINION();
}
// C entry point: runs inference on the buffers bound to the predictor.
void TenorRTPredictor_Run(PredictorHandle predictor_handle) {
  START_C_DEFINION();
  get_predictor_from_handle(predictor_handle)->Run();
  END_C_DEFINION();
}
// C entry point: number of output layer names the predictor was created
// with, or -1 if an exception was raised.
int TenorRTPredictor_GetNumOutputs(PredictorHandle predictor_handle) {
  START_C_DEFINION();
  const auto &outputs =
      get_predictor_from_handle(predictor_handle)->output_layer_names_;
  return static_cast<int>(outputs.size());
  END_C_DEFINION(-1);
}
// C entry point: returns a malloc'ed host copy of output `name`. Also writes
// the output's rank to *ndims and a malloc'ed array of its dimensions to
// *res_dims; the caller must free() both buffers. Returns nullptr on error,
// in which case no buffer is handed out and the out-params are left
// untouched.
void *TenorRTPredictor_GetOutput(PredictorHandle predictor_handle,
                                 const char *name, int32_t *ndims,
                                 int32_t **res_dims) {
  START_C_DEFINION();
  if (ndims == nullptr || res_dims == nullptr) {
    throw std::runtime_error("ndims and res_dims must not be null");
  }
  auto predictor = get_predictor_from_handle(predictor_handle);
  const auto dims = predictor->GetOutputShape(name);
  void *data = predictor->GetOutputData(name);
  const size_t dims_byte_count = sizeof(int32_t) * dims.size();
  auto *dims_copy = static_cast<int32_t *>(malloc(dims_byte_count));
  if (dims_copy == nullptr && dims_byte_count != 0) {
    free(data); // don't leak the output buffer on this error path
    throw std::runtime_error("cannot allocate output dimensions");
  }
  std::copy(dims.begin(), dims.end(), dims_copy);
  *ndims = static_cast<int32_t>(dims.size());
#ifdef DEBUG
  std::cout << __LINE__ << " >>> "
            << "*ndims = " << *ndims << "\n";
#endif
  *res_dims = dims_copy;
  return data;
  END_C_DEFINION(nullptr);
}
// C entry point: reports the global has_error flag. The handle is ignored
// because the flag is shared by all predictors.
bool TenorRTPredictor_HasError(PredictorHandle predictor_handle) {
  (void)predictor_handle;
  const bool errored = has_error;
  return errored;
}
// C entry point: the global error message. The handle is ignored. The
// returned pointer belongs to error_string and is valid only until the error
// state next changes.
const char *TenorRTPredictor_GetLastError(PredictorHandle predictor_handle) {
  (void)predictor_handle;
  const std::string &message = error_string;
  return message.c_str();
}
// C entry point: destroys the predictor behind `predictor_handle`.
// A null handle is reported as an error by get_predictor_from_handle.
void TenorRTPredictor_Delete(PredictorHandle predictor_handle) {
  START_C_DEFINION();
  // get_predictor_from_handle throws on null, so the pointer is valid here.
  delete get_predictor_from_handle(predictor_handle);
  END_C_DEFINION();
}
// C entry point: on first use, creates the predictor's profile labelled with
// `name` and `metadata` (null is treated as ""). Later calls only reset the
// existing profile, and their name/metadata are not used.
void TenorRTPredictor_StartProfiling(PredictorHandle predictor_handle,
                                     const char *name, const char *metadata) {
  START_C_DEFINION();
  auto predictor = get_predictor_from_handle(predictor_handle);
  if (predictor->prof_ != nullptr) {
    predictor->prof_->reset();
  } else {
    const char *const label = (name != nullptr) ? name : "";
    const char *const meta = (metadata != nullptr) ? metadata : "";
    predictor->prof_ = new profile(label, meta);
  }
  END_C_DEFINION();
}
// C entry point: calls end() on the predictor's profile, if one was started.
void TenorRTPredictor_EndProfiling(PredictorHandle pred) {
  START_C_DEFINION();
  profile *const prof = get_predictor_from_handle(pred)->prof_;
  if (prof != nullptr) {
    prof->end();
  }
  END_C_DEFINION();
}
// C entry point: returns a strdup'ed copy of the profile contents, to be
// freed by the caller. Returns "" when the handle is null or no profile was
// started (a null handle is not an error here), and nullptr if an exception
// was raised.
char *TenorRTPredictor_ReadProfiling(PredictorHandle pred) {
  START_C_DEFINION();
  const auto predictor = (Predictor *)pred;
  if (predictor == nullptr || predictor->prof_ == nullptr) {
    return strdup("");
  }
  return strdup(predictor->prof_->read().c_str());
  END_C_DEFINION(nullptr);
}
// Initializes the TensorRT plugin library, logging through gLogger, with an
// empty plugin namespace.
// NOTE(review): the result of initLibNvInferPlugins is ignored.
void TensoRT_Init() {
  const char *const plugin_namespace = "";
  initLibNvInferPlugins(&gLogger, plugin_namespace);
}
#endif // __linux__