Skip to content

Commit

Permalink
generate attribute_completion_and_shape_inference
Browse files Browse the repository at this point in the history
  • Loading branch information
okdshin committed Oct 9, 2018
1 parent 34638d1 commit 6c9c64c
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions menoh/attribute_completion_and_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ namespace menoh_impl {
auto graph = make_graph(model_data.node_list); // FIXME reorder nodes
model_data.node_list = graph.node_list();
for(auto& node : model_data.node_list) {
auto input = [&node](int i){
auto input = [&node](auto i){
return node.input_name_list.at(i);
};
auto output = [&node](int i){
auto output = [&node](auto i){
return node.output_name_list.at(i);
};

Expand Down Expand Up @@ -147,15 +147,19 @@ node.attribute_table.emplace(
{

auto count_include_pad = get<int>(node.attribute_table.at("count_include_pad"));
static_cast<void>(count_include_pad); // maybe unused


auto kernel_shape = get<ints>(node.attribute_table.at("kernel_shape"));
static_cast<void>(kernel_shape); // maybe unused


auto pads = get<ints>(node.attribute_table.at("pads"));
static_cast<void>(pads); // maybe unused


auto strides = get<ints>(node.attribute_table.at("strides"));
static_cast<void>(strides); // maybe unused


add_variable_to_table(output(0), dtype_of(input(0)),
Expand Down Expand Up @@ -207,12 +211,15 @@ node.attribute_table.emplace(
{

auto epsilon = get<float>(node.attribute_table.at("epsilon"));
static_cast<void>(epsilon); // maybe unused


auto momentum = get<float>(node.attribute_table.at("momentum"));
static_cast<void>(momentum); // maybe unused


auto spatial = get<int>(node.attribute_table.at("spatial"));
static_cast<void>(spatial); // maybe unused


assert(node.input_name_list.size() > 0);
Expand Down Expand Up @@ -240,10 +247,11 @@ assert(!"attribute not found: axis");
{

auto axis = get<int>(node.attribute_table.at("axis"));
static_cast<void>(axis); // maybe unused


auto output_dims = dims_of(input(0));
for(int i = 1; i < node.input_name_list.size(); ++i) {
for(unsigned int i = 1; i < node.input_name_list.size(); ++i) {
// TODO dim check
output_dims.at(axis) += dims_of(input(i)).at(axis);
}
Expand Down Expand Up @@ -319,18 +327,23 @@ node.attribute_table.emplace(
{

auto dilations = get<ints>(node.attribute_table.at("dilations"));
static_cast<void>(dilations); // maybe unused


auto group = get<int>(node.attribute_table.at("group"));
static_cast<void>(group); // maybe unused


auto kernel_shape = get<ints>(node.attribute_table.at("kernel_shape"));
static_cast<void>(kernel_shape); // maybe unused


auto pads = get<ints>(node.attribute_table.at("pads"));
static_cast<void>(pads); // maybe unused


auto strides = get<ints>(node.attribute_table.at("strides"));
static_cast<void>(strides); // maybe unused


add_variable_to_table(output(0), dtype_of(input(0)),
Expand Down Expand Up @@ -418,7 +431,7 @@ node.attribute_table.emplace(
ints input_size(input_profile.dims().begin()+2,
input_profile.dims().end());

for(int i = 0; i < kernel_ndims; ++i) {
for(unsigned int i = 0; i < kernel_ndims; ++i) {
auto total_padding = strides[i] * (input_size[i] - 1)
+ output_padding[i] + kernel_shape[i] - output_shape[i];
pads[i] = total_padding - (total_padding/2);
Expand All @@ -432,18 +445,23 @@ node.attribute_table.emplace(
{

auto dilations = get<ints>(node.attribute_table.at("dilations"));
static_cast<void>(dilations); // maybe unused


auto group = get<int>(node.attribute_table.at("group"));
static_cast<void>(group); // maybe unused


auto kernel_shape = get<ints>(node.attribute_table.at("kernel_shape"));
static_cast<void>(kernel_shape); // maybe unused


auto output_padding = get<ints>(node.attribute_table.at("output_padding"));
static_cast<void>(output_padding); // maybe unused


auto strides = get<ints>(node.attribute_table.at("strides"));
static_cast<void>(strides); // maybe unused


add_variable_to_table(output(0), dtype_of(input(0)),
Expand Down Expand Up @@ -473,6 +491,7 @@ node.attribute_table.emplace(
{

auto alpha = get<float>(node.attribute_table.at("alpha"));
static_cast<void>(alpha); // maybe unused


assert(node.input_name_list.size() > 0);
Expand Down Expand Up @@ -550,15 +569,19 @@ node.attribute_table.emplace(
{

auto alpha = get<float>(node.attribute_table.at("alpha"));
static_cast<void>(alpha); // maybe unused


auto beta = get<float>(node.attribute_table.at("beta"));
static_cast<void>(beta); // maybe unused


auto transA = get<int>(node.attribute_table.at("transA"));
static_cast<void>(transA); // maybe unused


auto transB = get<int>(node.attribute_table.at("transB"));
static_cast<void>(transB); // maybe unused


auto a_dims = dims_of(input(0));
Expand Down Expand Up @@ -603,6 +626,7 @@ node.attribute_table.emplace(
{

auto alpha = get<float>(node.attribute_table.at("alpha"));
static_cast<void>(alpha); // maybe unused


assert(node.input_name_list.size() > 0);
Expand Down Expand Up @@ -663,15 +687,19 @@ assert(!"attribute not found: size");
{

auto alpha = get<float>(node.attribute_table.at("alpha"));
static_cast<void>(alpha); // maybe unused


auto beta = get<float>(node.attribute_table.at("beta"));
static_cast<void>(beta); // maybe unused


auto bias = get<float>(node.attribute_table.at("bias"));
static_cast<void>(bias); // maybe unused


auto size = get<float>(node.attribute_table.at("size"));
static_cast<void>(size); // maybe unused


assert(node.input_name_list.size() > 0);
Expand Down Expand Up @@ -732,15 +760,19 @@ node.attribute_table.emplace(
{

auto kernel_shape = get<ints>(node.attribute_table.at("kernel_shape"));
static_cast<void>(kernel_shape); // maybe unused


auto pads = get<ints>(node.attribute_table.at("pads"));
static_cast<void>(pads); // maybe unused


auto storage_order = get<int>(node.attribute_table.at("storage_order"));
static_cast<void>(storage_order); // maybe unused


auto strides = get<ints>(node.attribute_table.at("strides"));
static_cast<void>(strides); // maybe unused


add_variable_to_table(output(0), dtype_of(input(0)),
Expand Down Expand Up @@ -786,6 +818,7 @@ node.attribute_table.emplace(
{

auto axis = get<int>(node.attribute_table.at("axis"));
static_cast<void>(axis); // maybe unused


assert(node.input_name_list.size() > 0);
Expand Down Expand Up @@ -848,7 +881,7 @@ else
if(node.op_type == "Transpose") {

ints perm(ndims_of(input(0)));
for(int i = 0; i < perm.size(); ++i) {{
for(unsigned int i = 0; i < perm.size(); ++i) {{
perm.at(i) = perm.size()-i-1;
}}

Expand All @@ -867,11 +900,12 @@ node.attribute_table.emplace(
{

auto perm = get<ints>(node.attribute_table.at("perm"));
static_cast<void>(perm); // maybe unused


auto input_dims = dims_of(input(0));
ints output_dims(input_dims.size());
for(int i = 0; i < input_dims.size(); ++i) {
for(unsigned int i = 0; i < input_dims.size(); ++i) {
output_dims.at(i) = input_dims.at(perm.at(i));
}
add_variable_to_table(output(0), dtype_of(input(0)), output_dims);
Expand Down

0 comments on commit 6c9c64c

Please sign in to comment.