Skip to content

Commit

Permalink
Pass IValue from c10 dispatcher to caffe2 operator (pytorch#16065)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#16065

Before, we registered the caffe2 kernel with the c10 dispatcher using plain C types.
Now, we pass in IValues, which avoids the unwrapping in between.

Reviewed By: ezyang

Differential Revision: D13689036

fbshipit-source-id: b976a2c46a5a541f6a926b3df255e8a535e32420
  • Loading branch information
smessmer authored and facebook-github-bot committed Jan 19, 2019
1 parent c904416 commit e8b872a
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions caffe2/operators/layer_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,15 @@ to the end.)
// Register layer norm with c10
namespace {
template <class DataType>
void layer_norm_c10(
const at::Tensor& X_,
const at::Tensor& Y_,
const at::Tensor& mean_,
const at::Tensor& sig_,
int axis,
float epsilon,
c10::intrusive_ptr<caffe2::Blob> cache_) {
caffe2::Tensor X{c10::C10Tensor(X_)};
caffe2::Tensor Y{c10::C10Tensor(Y_)};
caffe2::Tensor mean{c10::C10Tensor(mean_)};
caffe2::Tensor sig{c10::C10Tensor(sig_)};
c10::IValue layer_norm_c10(c10::ArrayRef<c10::IValue> inputs) {
caffe2::Tensor X{c10::C10Tensor(inputs[0].toTensor())};
caffe2::Tensor Y{c10::C10Tensor(inputs[1].toTensor())};
caffe2::Tensor mean{c10::C10Tensor(inputs[2].toTensor())};
caffe2::Tensor sig{c10::C10Tensor(inputs[3].toTensor())};
int64_t axis = inputs[4].toInt();
float epsilon = inputs[5].toDouble();
caffe2::CPUContext context;
c10::core::opschema::LayerNorm::Cache* cache = cache_->GetMutable<c10::core::opschema::LayerNorm::Cache>();
c10::core::opschema::LayerNorm::Cache* cache = inputs[6].toBlob()->GetMutable<c10::core::opschema::LayerNorm::Cache>();
if (!cache->scale.has_value()) {
cache->scale = at::Tensor(c10::C10Tensor(caffe2::Tensor{caffe2::CPU}));
}
Expand All @@ -219,6 +214,7 @@ void layer_norm_c10(
caffe2::LayerNormOp<caffe2::CPUContext>::runLayerNorm<DataType>(
X, &Y, &mean, &sig, canonical_axis, epsilon, &scale, &bias, static_cast<caffe2::CPUContext*>(&context)
);
return c10::IValue();
}
}
namespace c10 {
Expand Down

0 comments on commit e8b872a

Please sign in to comment.