Support ConstantOfShape in Caffe2 ONNX Backend (pytorch#16108)
Summary:
Adds support for the ONNX ConstantOfShape operator to the Caffe2 ONNX backend. The operator is converted to Caffe2's ConstantFill, with the runtime shape tensor passed through as an input (input_as_shape=1) rather than as a static shape argument. This PR is a prerequisite for landing pytorch#16095.
Pull Request resolved: pytorch#16108

Reviewed By: BIT-silence

Differential Revision: D13725722

Pulled By: houseroad

fbshipit-source-id: 28c0fb72f075cd04f9db44dfab0163844c20c620
houseroad authored and facebook-github-bot committed Jan 19, 2019
1 parent b436f94 commit daedec2
Showing 4 changed files with 158 additions and 56 deletions.
170 changes: 126 additions & 44 deletions caffe2/onnx/backend.cc
@@ -303,6 +303,7 @@ Caffe2Backend::get_renamed_operators() const {
       {"Unsqueeze", "ExpandDims"},
       {"Tile", "NumpyTile"},
       {"DynamicSlice", "Slice"},
+      {"ConstantOfShape", "ConstantFill"},
       {"RandomNormal", "GaussianFill"}};
   return kRenamedOperators;
 }
@@ -340,6 +341,7 @@ Caffe2Backend::get_special_operators() const {
       {"ArgMin", &Caffe2Backend::CreateArgMaxMin},
       {"Cast", &Caffe2Backend::CreateCast},
       {"Constant", &Caffe2Backend::CreateConstant},
+      {"ConstantOfShape", &Caffe2Backend::CreateConstantOfShape},
       {"Conv", &Caffe2Backend::CreateConvPoolOpBase},
       {"AveragePool", &Caffe2Backend::CreatePadPool},
       {"GlobalAveragePool", &Caffe2Backend::CreatePadPool},
@@ -459,6 +461,20 @@ Caffe2Ops Caffe2Backend::CreateConstant(
   return ret;
 }
 
+Caffe2Ops Caffe2Backend::CreateConstantOfShape(
+    OnnxNode* onnx_node,
+    const ConversionContext& ctx) {
+  CAFFE_ENFORCE_EQ(onnx_node->node.input_size(), 1);
+  CAFFE_ENFORCE_EQ(onnx_node->node.output_size(), 1);
+
+  Caffe2Ops ret;
+  auto* c2_op = ret.ops.Add();
+  const auto* value = onnx_node->attributes.get<const TensorProto*>("value");
+  BuildTensorFillingOp(c2_op, *value, onnx_node->node.output(0), onnx_node->node.input(0));
+
+  return ret;
+}
+
 // Note [Caffe2 ConvPoolOpBase]
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 // To understand what is going on here, we have to talk a little bit about
@@ -1673,62 +1689,128 @@ void ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(caffe2::OperatorDef
 void Caffe2Backend::BuildTensorFillingOp(
     caffe2::OperatorDef* c2_op,
     const TensorProto& onnx_tensor,
-    const std::string& name) {
-  auto fill_name = name.empty() ? onnx_tensor.name() : name;
+    const std::string& output_name,
+    const std::string& shape_name) {
+  auto fill_name = output_name.empty() ? onnx_tensor.name() : output_name;
   CAFFE_ENFORCE(!fill_name.empty());
 
   if (onnx_tensor.has_segment()) {
     CAFFE_THROW("Currently not supporting loading segments.");
   }
 
   auto* c2_values = c2_op->add_arg();
-  c2_values->set_name("values");
-
-  if (onnx_tensor.data_type() == TensorProto::FLOAT) {
-    c2_op->set_type("GivenTensorFill");
-    auto* floats = c2_values->mutable_floats();
-    if (!TryConvertingTensorRawValues<float>(onnx_tensor, floats)) {
-      floats->CopyFrom(onnx_tensor.float_data());
-    }
-  } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
-    c2_op->set_type("GivenTensorDoubleFill");
-    ::google::protobuf::RepeatedField<double> tmp;
-    const ::google::protobuf::RepeatedField<double>* src = &tmp;
-    if (!TryConvertingTensorRawValues<double>(onnx_tensor, &tmp)) {
-      src = &onnx_tensor.double_data();
-    }
-    for (const auto i : *src) {
-      c2_values->add_floats(i);
-    }
-  } else if (onnx_tensor.data_type() == TensorProto::INT64) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::int64>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::UINT32) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::BOOL) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::UINT8) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::uint8>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::INT8) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::UINT16) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::uint16>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::INT16) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::int16>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::INT32) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::int32>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::STRING) {
-    c2_op->set_type("GivenTensorStringFill");
-    auto* strings = c2_values->mutable_strings();
-    strings->CopyFrom(onnx_tensor.string_data());
-  } else {
-    CAFFE_THROW("unrecognized tensor type: ", onnx_tensor.data_type());
-  }
-
-  auto* c2_shape = c2_op->add_arg();
-  c2_shape->set_name("shape");
-  for (const auto d : onnx_tensor.dims()) {
-    c2_shape->add_ints(d);
-  }
+  // if shape_name is empty, we generate a GivenTensorFill;
+  // otherwise, we generate a ConstantFill, which accepts shape as an input
+  if (shape_name.empty()) {
+    // GivenTensor*Fill uses values
+    c2_values->set_name("values");
+    if (onnx_tensor.data_type() == TensorProto::FLOAT) {
+      c2_op->set_type("GivenTensorFill");
+      auto* floats = c2_values->mutable_floats();
+      if (!TryConvertingTensorRawValues<float>(onnx_tensor, floats)) {
+        floats->CopyFrom(onnx_tensor.float_data());
+      }
+    } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
+      c2_op->set_type("GivenTensorDoubleFill");
+      ::google::protobuf::RepeatedField<double> tmp;
+      const ::google::protobuf::RepeatedField<double>* src = &tmp;
+      if (!TryConvertingTensorRawValues<double>(onnx_tensor, &tmp)) {
+        src = &onnx_tensor.double_data();
+      }
+      for (const auto i : *src) {
+        c2_values->add_floats(i);
+      }
+    } else if (onnx_tensor.data_type() == TensorProto::INT64) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::int64>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::UINT32) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::BOOL) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::UINT8) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::uint8>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::INT8) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::UINT16) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::uint16>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::INT16) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::int16>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::INT32) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::int32>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::STRING) {
+      c2_op->set_type("GivenTensorStringFill");
+      auto* strings = c2_values->mutable_strings();
+      strings->CopyFrom(onnx_tensor.string_data());
+    } else {
+      CAFFE_THROW("unrecognized tensor type: ", onnx_tensor.data_type());
+    }
+    auto* c2_shape = c2_op->add_arg();
+    c2_shape->set_name("shape");
+    for (const auto d : onnx_tensor.dims()) {
+      c2_shape->add_ints(d);
+    }
+  } else {
+    int value_size = 1;
+    for (const auto d : onnx_tensor.dims()) {
+      value_size *= d;
+    }
+    CAFFE_ENFORCE(value_size == 1);
+    auto c2_input_as_shape = c2_op->add_arg();
+    c2_input_as_shape->set_name("input_as_shape");
+    c2_input_as_shape->set_i(1);
+    c2_values->set_name("value");
+    auto* c2_dtype = c2_op->add_arg();
+    c2_dtype->set_name("dtype");
+    if (onnx_tensor.data_type() == TensorProto::FLOAT) {
+      c2_dtype->set_i(caffe2::TensorProto::FLOAT);
+      if (onnx_tensor.float_data_size() > 0) {
+        c2_values->set_f(onnx_tensor.float_data(0));
+      } else {
+        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(float));
+        float f;
+        memcpy(&f, onnx_tensor.raw_data().c_str(), sizeof(float));
+        c2_values->set_f(f);
+      }
+    } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
+      c2_dtype->set_i(caffe2::TensorProto::DOUBLE);
+      if (onnx_tensor.double_data_size() > 0) {
+        c2_values->set_f(static_cast<float>(onnx_tensor.double_data(0)));
+      } else {
+        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(double));
+        double d;
+        memcpy(&d, onnx_tensor.raw_data().c_str(), sizeof(double));
+        c2_values->set_f(static_cast<float>(d));
+      }
+    } else if (onnx_tensor.data_type() == TensorProto::INT64) {
+      c2_dtype->set_i(caffe2::TensorProto::INT64);
+      if (onnx_tensor.int64_data_size() > 0) {
+        c2_values->set_i(onnx_tensor.int64_data(0));
+      } else {
+        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int64_t));
+        int64_t i;
+        memcpy(&i, onnx_tensor.raw_data().c_str(), sizeof(int64_t));
+        c2_values->set_i(i);
+      }
+    } else if (onnx_tensor.data_type() == TensorProto::INT32) {
+      c2_dtype->set_i(caffe2::TensorProto::INT32);
+      if (onnx_tensor.int32_data_size() > 0) {
+        c2_values->set_i(onnx_tensor.int32_data(0));
+      } else {
+        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int32_t));
+        int32_t i;
+        memcpy(&i, onnx_tensor.raw_data().c_str(), sizeof(int32_t));
+        c2_values->set_i(i);
+      }
+    } else {
+      // TODO: support more data types
+      std::stringstream oss;
+      oss << "Unsupported dtype: " << onnx_tensor.data_type();
+      CAFFE_THROW(oss.str());
+    }
+    // ConstantFill uses value
+    c2_op->set_type("ConstantFill");
+    c2_op->add_input(shape_name);
+  }
+
   c2_op->add_output(fill_name);
 }
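For reference, the new conversion path can be exercised end to end from Python. The sketch below builds a one-node ONNX model and runs it through the Caffe2 backend, which now rewrites ConstantOfShape into a ConstantFill with input_as_shape=1. The model and blob names are illustrative, not part of this commit, and it assumes an onnx package with opset 9 alongside the caffe2 Python package:

import numpy as np
from onnx import TensorProto, helper

import caffe2.python.onnx.backend as c2_backend

# ConstantOfShape reads a 1-D int64 shape tensor at run time and fills a
# tensor of that shape with the single value from its "value" attribute.
value = helper.make_tensor("value", TensorProto.FLOAT, dims=[1], vals=[3.5])
node = helper.make_node("ConstantOfShape", ["shape"], ["y"], value=value)
graph = helper.make_graph(
    [node],
    "constant_of_shape_example",
    [helper.make_tensor_value_info("shape", TensorProto.INT64, [2])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3])])
model = helper.make_model(graph)

rep = c2_backend.prepare(model)
outputs = rep.run([np.array([2, 3], dtype=np.int64)])
print(outputs[0])  # 2x3 array filled with 3.5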

7 changes: 6 additions & 1 deletion caffe2/onnx/backend.h
@@ -157,7 +157,8 @@ class CAFFE2_API Caffe2Backend {
   void BuildTensorFillingOp(
       caffe2::OperatorDef* c2_op,
       const TensorProto& onnx_tensor,
-      const std::string& name = "");
+      const std::string& output_name = "",
+      const std::string& shape_name = "");
 
  private:
   using SpecialOpConverter =
@@ -192,6 +193,10 @@
 
   Caffe2Ops CreateConstant(OnnxNode* onnx_node, const ConversionContext& ctx);
 
+  Caffe2Ops CreateConstantOfShape(
+      OnnxNode* onnx_node,
+      const ConversionContext& ctx);
+
   Caffe2Ops CreateConvPoolOpBase(
       OnnxNode* onnx_node,
       const ConversionContext& ctx);
34 changes: 25 additions & 9 deletions caffe2/operators/filler_op.h
@@ -55,15 +55,31 @@ class FillerOp : public Operator<Context> {
     if (InputSize()) {
       auto shape = vector<int64_t>{};
       if (input_as_shape_) {
-        // Shape input must be in CPU context
-        auto& input = this->template Input<Tensor>(0, CPU);
-        CAFFE_ENFORCE_EQ(
-            input.dim(),
-            1,
-            "When input_as_shape is true, the input must be a 1D tensor of "
-            "data type int64_t");
-        auto* shape_data = input.template data<int64_t>();
-        shape.insert(shape.end(), shape_data, shape_data + input.dim32(0));
+        if (this->InputIsTensorType(0, CPU)) {
+          // originally, shape input must be in CPU context
+          auto& input = this->template Input<Tensor>(0, CPU);
+          CAFFE_ENFORCE_EQ(
+              input.dim(),
+              1,
+              "When input_as_shape is true, the input must be a 1D tensor of "
+              "data type int64_t");
+          CAFFE_ENFORCE(input.numel() > 0);
+          auto* shape_data = input.template data<int64_t>();
+          shape.insert(shape.end(), shape_data, shape_data + input.dim32(0));
+        } else {
+          // in the ONNX case, we allow shape to be in CUDA context
+          auto& input = Input(0);
+          CAFFE_ENFORCE_EQ(
+              input.dim(),
+              1,
+              "When input_as_shape is true, the input must be a 1D tensor of "
+              "data type int64_t");
+          CAFFE_ENFORCE(input.numel() > 0);
+          auto* shape_data = input.template data<int64_t>();
+          std::unique_ptr<int64_t[]> shape_data_copy = caffe2::make_unique<int64_t[]>(input.dim32(0));
+          context_.template CopyToCPU<int64_t>(input.dim32(0), shape_data, shape_data_copy.get());
+          shape.insert(shape.end(), shape_data_copy.get(), shape_data_copy.get() + input.dim32(0));
+        }
       } else {
         auto& input = Input(0);
         shape.insert(shape.end(), input.sizes().begin(), input.sizes().end());
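The Caffe2 side of the mapping can also be tried directly. Below is a minimal sketch (blob names illustrative) of ConstantFill with input_as_shape=1, the configuration the ONNX converter above emits, which this filler change allows to read its shape from an input blob:

import numpy as np
from caffe2.python import core, workspace

# The shape arrives as a 1-D int64 blob instead of a static "shape" argument.
workspace.FeedBlob("shape", np.array([2, 3], dtype=np.int64))

op = core.CreateOperator(
    "ConstantFill",
    ["shape"],   # input blob holding the desired output shape
    ["filled"],
    input_as_shape=1,
    value=3.5)   # single fill value; dtype defaults to float
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("filled"))  # 2x3 array filled with 3.5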
3 changes: 1 addition & 2 deletions caffe2/python/onnx/tests/onnx_backend_test.py
@@ -41,7 +41,6 @@
        '|test_convtranspose.*' # ConvTranspose needs some more complicated translation
        '|test_mvn.*' # MeanVarianceNormalization is experimental and not supported.
        '|test_dynamic_slice.*' # DynamicSlice is experimental and not supported.
-       '|test_constantlike.*' # Needs implementation
        '|test_eyelike.*' # Needs implementation
        '|test_maxunpool.*' # Needs implementation
        '|test_acosh.*' # Needs implementation
@@ -51,7 +50,7 @@
        '|test_scan.*' # Needs implementation
        '|test_isnan.*' # Needs implementation
        '|test_scatter.*' # Should be similar to ScatterAssign
-       '|test_constantofshape.*' # Needs implementation
+       '|test_constantofshape_int.*' # Needs implementation
        '|test_where.*' # Needs implementation
        '|test_shrink.*' # Needs implementation
        ')')
