forked from alibaba/MNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
PoolingTf.cpp
74 lines (61 loc) · 1.93 KB
/
PoolingTf.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
//
// PoolingTf.cpp
// MNNConverter
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "TfUtils.hpp"
#include "graph.pb.h"
#include "tfOpConverter.hpp"
// Declares the PoolingTf converter type whose opType()/type()/run() members are
// defined below (macro presumably provided by tfOpConverter.hpp).
DECLARE_OP_CONVERTER(PoolingTf);
// Both TensorFlow pooling flavours handled by this converter (MaxPool and
// AvgPool) lower to MNN's single Pooling op; the flavour lives in PoolT::type.
MNN::OpType PoolingTf::opType() {
    const MNN::OpType loweredOp = MNN::OpType_Pooling;
    return loweredOp;
}
// Parameter kind matching the MNN::PoolT object that run() stores in dstOp->main.
MNN::OpParameter PoolingTf::type() {
    const MNN::OpParameter paramKind = MNN::OpParameter_Pool;
    return paramKind;
}
// Converts a TensorFlow MaxPool/AvgPool node into an MNN Pooling op.
// input: tensor
//
// TensorFlow's `ksize` and `strides` attributes are laid out like the input
// tensor: [1, h, w, 1] for NHWC (the default when `data_format` is absent) and
// [1, 1, h, w] for NCHW / NCHW_VECT_C. The spatial entries are therefore picked
// according to the node's `data_format` instead of fixed indices 1 and 2.
// Explicit pad amounts are left at 0; only the pad mode (VALID/SAME) is
// recorded and the runtime computes the actual padding.
//
// dstOp      - receives the newly allocated MNN::PoolT in dstOp->main.value.
// srcNode    - the TensorFlow node being converted (expects exactly one input).
// tempGraph  - unused by this converter.
void PoolingTf::run(MNN::OpT *dstOp, TmpNode *srcNode, TmpGraph *tempGraph) {
    auto pool = new MNN::PoolT; // ownership is handed to dstOp->main below
    tensorflow::AttrValue value;
    int kernel_size_h = 1;
    int kernel_size_w = 1;
    int stride_h = 1;
    int stride_w = 1;
    if (srcNode->opType == "AvgPool") {
        pool->type = MNN::PoolType_AVEPOOL;
    } else if (srcNode->opType == "MaxPool") {
        pool->type = MNN::PoolType_MAXPOOL;
    } else {
        DLOG(ERROR) << "Not Support This Pooling Type: " << srcNode->opType;
    }

    // Position of the H entry inside ksize/strides; W immediately follows it.
    int hIndex = 1; // NHWC: [1, h, w, 1]
    if (find_attr_value(srcNode->tfNode, "data_format", value) && value.s().compare(0, 4, "NCHW") == 0) {
        hIndex = 2; // NCHW / NCHW_VECT_C: [1, 1, h, w, ...]
    }
    const int wIndex = hIndex + 1;

    if (find_attr_value(srcNode->tfNode, "ksize", value)) {
        // Bounds-check first: protobuf repeated-field access is unchecked in release builds.
        if (value.list().i_size() > wIndex) {
            kernel_size_h = static_cast<int>(value.list().i(hIndex));
            kernel_size_w = static_cast<int>(value.list().i(wIndex));
        } else {
            DLOG(ERROR) << "Pooling ksize has too few elements: " << value.list().i_size();
        }
    }
    pool->kernelX = kernel_size_w;
    pool->kernelY = kernel_size_h;

    if (find_attr_value(srcNode->tfNode, "strides", value)) {
        if (value.list().i_size() > wIndex) {
            stride_h = static_cast<int>(value.list().i(hIndex));
            stride_w = static_cast<int>(value.list().i(wIndex));
        } else {
            DLOG(ERROR) << "Pooling strides has too few elements: " << value.list().i_size();
        }
    }
    pool->strideX = stride_w;
    pool->strideY = stride_h;

    if (find_attr_value(srcNode->tfNode, "padding", value)) {
        if (value.s() == "VALID") {
            pool->padType = MNN::PoolPadType_VALID;
        } else if (value.s() == "SAME") {
            pool->padType = MNN::PoolPadType_SAME;
        } else {
            DLOG(ERROR) << "Not Support This Padding Mode";
        }
    }
    pool->padY = 0; // runtime compute this pad
    pool->padX = 0;
    pool->isGlobal = false; // TODO
    dstOp->main.value = pool;
    DCHECK(srcNode->inTensors.size() == 1) << "Pooling Input ERROR";
}
// Register PoolingTf for both TensorFlow op names that run() handles; the second
// macro argument is presumably the TF op type string used for lookup.
REGISTER_CONVERTER(PoolingTf, MaxPool);
REGISTER_CONVERTER(PoolingTf, AvgPool);