forked from dusty-nv/jetson-inference
-
Notifications
You must be signed in to change notification settings - Fork 0
/
imageNet.h
99 lines (78 loc) · 2.77 KB
/
imageNet.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
/*
* http://github.com/dusty-nv/jetson-inference
*/
#ifndef __IMAGE_NET_H__
#define __IMAGE_NET_H__
#include "tensorNet.h"
/**
 * Image recognition with GoogleNet/Alexnet or custom models, using TensorRT.
 *
 * Instances are obtained through the static Create() factories; the default
 * constructor is protected.
 */
class imageNet : public tensorNet
{
public:
	/**
	 * Network choice enumeration.
	 */
	enum NetworkType
	{
		ALEXNET,
		GOOGLENET
	};

	/**
	 * Load a new network instance of one of the built-in network types.
	 * @param networkType Which built-in network to load (defaults to GOOGLENET).
	 * @returns Pointer to the new instance.
	 *          NOTE(review): presumably NULL on failure -- confirm against the implementation.
	 */
	static imageNet* Create( NetworkType networkType=GOOGLENET );

	/**
	 * Load a new network instance from a custom model.
	 * @param prototxt_path File path to the deployable network prototxt
	 * @param model_path File path to the caffemodel
	 * @param mean_binary File path to the mean value binary proto
	 * @param class_labels File path to list of class name labels
	 * @param input Name of the input layer blob (defaults to "data").
	 * @param output Name of the output layer blob (defaults to "prob").
	 * @returns Pointer to the new instance.
	 *          NOTE(review): presumably NULL on failure -- confirm against the implementation.
	 */
	static imageNet* Create( const char* prototxt_path, const char* model_path, const char* mean_binary,
	                         const char* class_labels, const char* input="data", const char* output="prob" );

	/**
	 * Destroy
	 */
	virtual ~imageNet();

	/**
	 * Determine the maximum likelihood image class.
	 * @param rgba float4 input image in CUDA device memory.
	 * @param width width of the input image in pixels.
	 * @param height height of the input image in pixels.
	 * @param confidence optional pointer to float filled with confidence value.
	 * @returns Index of the maximum class, or -1 on error.
	 */
	int Classify( float* rgba, uint32_t width, uint32_t height, float* confidence=NULL );

	/**
	 * Retrieve the number of image recognition classes (typically 1000)
	 */
	inline uint32_t GetNumClasses() const { return mOutputClasses; }

	/**
	 * Retrieve the description of a particular class.
	 * @param index Class index, e.g. a non-negative value returned by Classify().
	 * @returns The description string, or NULL if index is outside the loaded
	 *          description list (this includes a -1 error code from Classify()
	 *          converted to uint32_t).
	 */
	inline const char* GetClassDesc( uint32_t index ) const
	{
		// Guard the vector access: operator[] with an out-of-range index is undefined behaviour.
		return (index < mClassDesc.size()) ? mClassDesc[index].c_str() : NULL;
	}

	/**
	 * Retrieve the class synset category of a particular class.
	 * @param index Class index, e.g. a non-negative value returned by Classify().
	 * @returns The synset ID string, or NULL if index is outside the loaded
	 *          synset list.
	 */
	inline const char* GetClassSynset( uint32_t index ) const
	{
		// Same bounds guard as GetClassDesc(); the two lists are checked independently.
		return (index < mClassSynset.size()) ? mClassSynset[index].c_str() : NULL;
	}

	/**
	 * Retrieve the network type (alexnet or googlenet)
	 */
	inline NetworkType GetNetworkType() const { return mNetworkType; }

	/**
	 * Retrieve a string describing the network name.
	 * Any type other than GOOGLENET is reported as "alexnet".
	 */
	inline const char* GetNetworkName() const { return (mNetworkType == GOOGLENET ? "googlenet" : "alexnet"); }

protected:
	// Construction is restricted to this class and subclasses; use Create() instead.
	imageNet();

	// Initialization helpers mirroring the two Create() overloads (defined in the implementation file).
	bool init( NetworkType networkType );
	bool init(const char* prototxt_path, const char* model_path, const char* mean_binary, const char* class_path, const char* input, const char* output);

	// Loads class information from the given file (defined in the implementation file).
	bool loadClassInfo( const char* filename );

	uint32_t mOutputClasses;               // class count reported by GetNumClasses()
	std::vector<std::string> mClassSynset; // 1000 class ID's (ie n01580077, n04325704)
	std::vector<std::string> mClassDesc;   // class descriptions served by GetClassDesc()
	NetworkType mNetworkType;              // value reported by GetNetworkType()
};
#endif