From c9044166a5bf449dfbb56052235d6266a3abb5a1 Mon Sep 17 00:00:00 2001
From: Sebastian Messmer
Date: Fri, 18 Jan 2019 15:55:57 -0800
Subject: [PATCH] Make c10 dispatcher use boxed kernel function pointers
 (#16051)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16051

This changes the kernels stored in the c10 dispatcher from plain C function pointers to IValue-based KernelFunction*.

Note that KernelFunction is currently taking an `ArrayRef` as arguments. A later diff will change that to it taking a `Stack*`.

Reviewed By: ezyang

Differential Revision: D13684518

fbshipit-source-id: 1fa54f60cec2e967b92a4a043d6e3ac1627ed991
---
 aten/src/ATen/core/dispatch/DispatchTable.h | 38 +++--
 aten/src/ATen/core/dispatch/Dispatcher.h | 32 ++---
 .../ATen/core/dispatch/KernelRegistration.h | 30 ++--
 aten/src/ATen/core/dispatch/OpSchema.h | 133 +++++++++++++++---
 aten/src/ATen/core/dispatch/OpSchema_test.cpp | 12 +-
 aten/src/ATen/core/opschema/layer_norm.cpp | 4 +
 aten/src/ATen/core/opschema/layer_norm.h | 17 +--
 caffe2/core/operator_c10wrapper.h | 59 ++++----
 .../operators/experimental/c10/cpu/add_cpu.cc | 12 +-
 .../experimental/c10/cpu/averaged_loss_cpu.cc | 11 +-
 .../experimental/c10/cpu/batch_gather_cpu.cc | 12 +-
 .../experimental/c10/cpu/batch_matmul_cpu.cc | 15 +-
 .../experimental/c10/cpu/cast_cpu.cc | 11 +-
 .../experimental/c10/cpu/concat_cpu.cc | 10 +-
 .../c10/cpu/enforce_finite_cpu.cc | 4 +-
 .../experimental/c10/cpu/expand_dims_cpu.cc | 15 +-
 .../operators/experimental/c10/cpu/fc_cpu.cc | 31 ++--
 .../experimental/c10/cpu/filler_cpu.cc | 56 ++++----
 .../experimental/c10/cpu/flatten_cpu.cc | 8 +-
 .../operators/experimental/c10/cpu/mul_cpu.cc | 12 +-
 .../experimental/c10/cpu/relu_cpu.cc | 8 +-
 .../experimental/c10/cpu/sigmoid_cpu.cc | 8 +-
 .../sigmoid_cross_entropy_with_logits_cpu.cc | 12 +-
 .../c10/cpu/sparse_lengths_sum_cpu.cc | 16 +--
 .../experimental/c10/cpu/stop_gradient_cpu.cc | 8 +-
 .../operators/experimental/c10/schemas/add.h | 8 +-
 .../experimental/c10/schemas/averaged_loss.cc | 3 +
 .../experimental/c10/schemas/averaged_loss.h | 11 +-
 .../experimental/c10/schemas/batch_gather.h | 8 +-
 .../experimental/c10/schemas/batch_matmul.cc | 3 +
 .../experimental/c10/schemas/batch_matmul.h | 13 +-
 .../operators/experimental/c10/schemas/cast.h | 8 +-
 .../experimental/c10/schemas/concat.h | 15 +-
 .../experimental/c10/schemas/enforce_finite.h | 4 +-
 .../experimental/c10/schemas/expand_dims.cc | 11 +-
 .../experimental/c10/schemas/expand_dims.h | 14 +-
 .../operators/experimental/c10/schemas/fc.cc | 5 +-
 .../operators/experimental/c10/schemas/fc.h | 17 +--
 .../experimental/c10/schemas/filler.cc | 38 ++---
 .../experimental/c10/schemas/filler.h | 60 +++-----
 .../experimental/c10/schemas/flatten.h | 6 +-
 .../operators/experimental/c10/schemas/mul.h | 8 +-
 .../operators/experimental/c10/schemas/relu.h | 4 +-
 .../experimental/c10/schemas/sigmoid.h | 4 +-
 .../sigmoid_cross_entropy_with_logits.h | 8 +-
 .../c10/schemas/sparse_lengths_sum.h | 10 +-
 .../experimental/c10/schemas/stop_gradient.h | 6 +-
 caffe2/operators/layer_norm_op.cc | 24 ++--
 torch/csrc/jit/c10_ops/layer_norm.cpp | 29 +++-
 49 files changed, 501 insertions(+), 390 deletions(-)

diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h
index 77d782e543d7df..39c850a14fcec6 100644
--- a/aten/src/ATen/core/dispatch/DispatchTable.h
+++ b/aten/src/ATen/core/dispatch/DispatchTable.h
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -20,9 +21,9 @@ template class
ThreadsafeOperatorTable_ final { public: template - void emplace(Key_&& key, void* value) { - bool res = map_.write([&](ska::flat_hash_map& map) -> bool { - auto result = map.emplace(std::forward(key), value); + void emplace(Key_&& key, KernelFunction value) { + bool res = map_.write([&](ska::flat_hash_map& map) -> bool { + auto result = map.emplace(std::forward(key), std::move(value)); return result.second; }); if (!res) { @@ -34,7 +35,7 @@ class ThreadsafeOperatorTable_ final { void erase(const Key& key) { auto num_removed = - map_.write([&](ska::flat_hash_map& map) -> size_t { + map_.write([&](ska::flat_hash_map& map) -> size_t { return map.erase(key); }); assert(num_removed <= 1); // This is not a multi-map @@ -44,11 +45,11 @@ class ThreadsafeOperatorTable_ final { } } - void* lookup(const Key& key) const { - return map_.read([&](const ska::flat_hash_map& map) -> void* { + const KernelFunction* lookup(const Key& key) const { + return map_.read([&](const ska::flat_hash_map& map) -> const KernelFunction* { auto found = map.find(key); if (found != map.end()) { - return found->second; + return &found->second; } else { return nullptr; } @@ -56,7 +57,7 @@ class ThreadsafeOperatorTable_ final { } private: - LeftRight> map_; + LeftRight> map_; }; } // namespace details @@ -86,9 +87,9 @@ class DispatchTable final { * @param dispatch_key Dispatch key to define when this kernel is selected */ void registerKernel( - typename Schema::signature::func_type* func, + KernelFunction func, typename Schema::dispatch::dispatch_key_type dispatch_key) { - kernels_.emplace(std::move(dispatch_key), reinterpret_cast(func)); + kernels_.emplace(std::move(dispatch_key), std::move(func)); } /** @@ -111,23 +112,20 @@ class DispatchTable final { * @param args Arguments to invoke the function with * @return Returned value of the operator */ - template - typename Schema::signature::return_type call(Args&&... args) const { + IValue call(ArrayRef args) const { // TODO Better error message, but need to take care that reference arguments // match non-reference arguments and so on. // static_assert(std::is_same::value, "Argument types don't match // operator signature"); - auto kernel_func = lookupKernelFunc_(args...); - return kernel_func(std::forward(args)...); + const auto& kernel_func = lookupKernelFunc_(args); + return kernel_func(args); } private: - template - typename Schema::signature::func_type* lookupKernelFunc_( - const Args&... args) const { - auto dispatch_key = Schema::dispatch::dispatch_key(args...); - void* found = kernels_.lookup(dispatch_key); + const KernelFunction& lookupKernelFunc_(ArrayRef args) const { + auto dispatch_key = Schema::dispatch::dispatch_key(args); + const KernelFunction* found = kernels_.lookup(dispatch_key); if (found == nullptr) { // TODO Better error message - include op name and dispatch key (i.e. 
// argument types) @@ -135,7 +133,7 @@ class DispatchTable final { std::string() + "Didn't find kernel to dispatch to for operator '" + Schema::metadata::name() + "'"); } - return reinterpret_cast(found); + return *found; } details::ThreadsafeOperatorTable_< diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 78eb10576ab9f7..7873db08281e74 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -9,6 +9,8 @@ namespace c10 { */ template class Dispatcher final { +private: + using Schema = OpSchema; public: // Implementation note: this class abstracts over the fact that we have per-operator // dispatch tables. This could be easily adjusted to have a single global hash @@ -16,44 +18,26 @@ class Dispatcher final { /** * Register an operator to the dispatch table for some operator schema. - * - * @tparam OpSchemaDef Operator schema to register this operator to (mandatory) - * @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::registerOp (inferred) - * @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::registerOp - * @return void */ - template - static void registerKernel(Args&&... args) { + static void registerKernel(KernelFunction kernel_func, typename Schema::dispatch::dispatch_key_type dispatch_key) { auto& dispatch_table_for_this_op = c10_dispatch_table(); - return dispatch_table_for_this_op.registerKernel(std::forward(args)...); + return dispatch_table_for_this_op.registerKernel(std::move(kernel_func), std::move(dispatch_key)); } /** * Remove an operator from the dispatch table for some operator schema. - * - * @tparam OpSchemaDef Operator schema to deregister from (mandatory) - * @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::deregisterOp (inferred) - * @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::deregisterOp - * @return void */ - template - static void deregisterKernel(Args&&... args) { + static void deregisterKernel(const typename Schema::dispatch::dispatch_key_type& dispatch_key) { auto& dispatch_table_for_this_op = c10_dispatch_table(); - return dispatch_table_for_this_op.deregisterKernel(std::forward(args)...); + return dispatch_table_for_this_op.deregisterKernel(dispatch_key); } /** * Perform a dynamic dispatch to some operator - * - * @tparam OpSchemaDef Operator schema to dispatch with (mandatory) - * @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::call (inferred) - * @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::call - * @return Return type of this operator */ - template - static typename OpSchema::signature::return_type call(Args&&... 
args) { + static IValue call(ArrayRef args) { auto& dispatch_table_for_this_op = c10_dispatch_table(); - return dispatch_table_for_this_op.call(std::forward(args)...); + return dispatch_table_for_this_op.call(args); } }; diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.h b/aten/src/ATen/core/dispatch/KernelRegistration.h index 6141a5f407eab0..13aa17ac528775 100644 --- a/aten/src/ATen/core/dispatch/KernelRegistration.h +++ b/aten/src/ATen/core/dispatch/KernelRegistration.h @@ -34,7 +34,7 @@ class KernelRegistrar final { * @param kernel The concrete function implementation to register * @param dispatch_key The dispatch key to register the function to */ - KernelRegistrar(typename Schema::signature::func_type* kernel, typename Schema::dispatch::dispatch_key_type dispatch_key) + KernelRegistrar(KernelFunction kernel, typename Schema::dispatch::dispatch_key_type dispatch_key) : dispatch_key_(std::move(dispatch_key)), owns_registration_(true) { Dispatcher::registerKernel(kernel, dispatch_key_); } @@ -78,8 +78,7 @@ class KernelRegistrar final { * The resulting full expression is implicitly convertible to a KernelRegistrar. * * @tparam OpSchemaDef The operator schema this is building a KernelRegistration for - * @tparam hasKernel Boolean for compile-time checking that a kernel is specified before finalizing the builder - * @tparam hasDispatchKey Boolean for compile-time checking thhat a dispatch key is specified before finalizing the builder + * @tparam FieldsPresentFlags Remembers which fields are already set in the builder */ template class KernelRegistrationBuilder final { @@ -89,15 +88,15 @@ class KernelRegistrationBuilder final { static constexpr uint64_t KERNEL_PRESENT = 0x01 << 0; static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 1; - c10::optional kernel_; + c10::optional kernel_; c10::optional dispatch_key_; public: - constexpr KernelRegistrationBuilder() + KernelRegistrationBuilder() : KernelRegistrationBuilder(c10::nullopt, c10::nullopt) {} - constexpr KernelRegistrationBuilder( - c10::optional kernel, + KernelRegistrationBuilder( + c10::optional kernel, c10::optional dispatch_key) : kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {} @@ -106,7 +105,7 @@ class KernelRegistrationBuilder final { * creates the object. 
* @return Produced KernelRegistrar */ - constexpr operator KernelRegistrar() && { + operator KernelRegistrar() && { static_assert(FieldsPresentFlags & KERNEL_PRESENT, "Forgot to call .kernel() in kernel registration"); static_assert(FieldsPresentFlags & DISPATCH_KEY_PRESENT, "Forgot to call .dispatchKey() in kernel registration"); return KernelRegistrar(std::move(*kernel_), std::move(*dispatch_key_)); @@ -117,9 +116,18 @@ class KernelRegistrationBuilder final { * @param kernel concrete function implementation to be registered * @return "this" for method chaining */ - constexpr KernelRegistrationBuilder kernel(typename Schema::signature::func_type* kernel_func) && { + KernelRegistrationBuilder kernel(KernelFunction kernel_func) && { static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration"); - return KernelRegistrationBuilder(*kernel_func, std::move(dispatch_key_)); + return KernelRegistrationBuilder(std::move(kernel_func), std::move(dispatch_key_)); + } + + /** + * Specify the concrete function implementation for this dispatch registration + * @param kernel concrete function implementation to be registered + * @return "this" for method chaining + */ + KernelRegistrationBuilder kernel(typename Schema::signature::func_type* kernel_func) && { + return std::move(*this).kernel(Schema::signature::wrap_kernel(kernel_func)); } /** @@ -127,7 +135,7 @@ class KernelRegistrationBuilder final { * @param dispatch_key dispatch key to register the function to * @return "this" for method chaining */ - constexpr KernelRegistrationBuilder dispatchKey(typename Schema::dispatch::dispatch_key_type dispatch_key) && { + KernelRegistrationBuilder dispatchKey(typename Schema::dispatch::dispatch_key_type dispatch_key) && { static_assert(!(FieldsPresentFlags & DISPATCH_KEY_PRESENT), "Tried to define kernel twice in same op registration"); return KernelRegistrationBuilder(std::move(kernel_), std::move(dispatch_key)); } diff --git a/aten/src/ATen/core/dispatch/OpSchema.h b/aten/src/ATen/core/dispatch/OpSchema.h index ad2c7672bd3457..db6f3e722d4247 100644 --- a/aten/src/ATen/core/dispatch/OpSchema.h +++ b/aten/src/ATen/core/dispatch/OpSchema.h @@ -1,13 +1,18 @@ #pragma once #include +#include #include #include +#include #include -#include +#include namespace c10 { +// TODO Use folly::Function for perf +using KernelFunction = std::function)>; + namespace details { /** @@ -16,7 +21,7 @@ namespace details { */ template using is_tensor_arg = std:: - is_same>>; + is_same>>; inline DeviceTypeId to_device_type_id(DeviceType device_type) { switch (device_type) { @@ -29,17 +34,40 @@ inline DeviceTypeId to_device_type_id(DeviceType device_type) { } } -inline TensorParameterDispatchKey tensor_to_dispatch_key(const C10Tensor& tensor) { +inline TensorParameterDispatchKey tensor_to_dispatch_key(const at::Tensor& tensor) { return TensorParameterDispatchKey{ - to_device_type_id(tensor.impl()->device_type()), + to_device_type_id(tensor.device().type()), LayoutId(0), - tensor.impl()->dtype().id()}; + tensor.dtype().id()}; +} + +template struct get_ith_tensor_arg_ { + static_assert(!std::is_same::value, "Index out of bounds"); +}; +template +struct get_ith_tensor_arg_, guts::enable_if_t::value>> { + static at::Tensor call(ArrayRef args) { + if (!args[offset].isTensor()) { + throw std::runtime_error("Expected argument " + guts::to_string(offset) + " to be of type Tensor but found different type."); + } + return args[offset].toTensor(); + } +}; +template +struct get_ith_tensor_arg_, 
guts::enable_if_t::value>> { + static at::Tensor call(ArrayRef args) { + return get_ith_tensor_arg_<(is_tensor_arg::value ? (index-1) : index), offset + 1, guts::typelist::typelist>::call(args); + } +}; +template at::Tensor get_ith_tensor_arg(ArrayRef args) { + return get_ith_tensor_arg_::call(args); } // Extract type ids for all tensors from an array of tensors -template -guts::array getDispatchTypeIds__(const guts::array& tensor_args, guts::index_sequence) { - return {tensor_to_dispatch_key(*tensor_args[indices])...}; +template +guts::array getDispatchTypeIds__(ArrayRef args, guts::index_sequence) { + using ParameterTypes = typename guts::function_traits::parameter_types; + return {tensor_to_dispatch_key(get_ith_tensor_arg(args))...}; } /** @@ -49,10 +77,9 @@ guts::array getDispatchTypeIds__( * @param args List of arguments to get type ids from * @return guts::array, where n is the number of tensor arguments (is_tensor_arg) in the class */ -template -guts::array getDispatchTypeIds_(const Args&... args) { - auto tensor_args = guts::filter_map([] (const C10Tensor& v){return &v;}, args...); - return getDispatchTypeIds__(tensor_args, guts::make_index_sequence()); +template +guts::array getDispatchTypeIds_(ArrayRef args) { + return getDispatchTypeIds__(args, guts::make_index_sequence()); } // TODO Test getDispatchTypeIds_ @@ -88,6 +115,63 @@ struct has_name_defined +struct ivalue_to_arg_type { + static T call(const IValue& v) { + return std::move(v).to(); + } +}; +template +struct ivalue_to_arg_type> { + static ArrayRef call(const IValue& v) { + return v.to>>()->elements(); + } +}; + +template struct _wrapKernel {}; +template struct _wrapKernel, FuncType> { + using parameter_types = guts::typelist::typelist; + + template + static KernelFunction call(FuncType* kernel, guts::index_sequence) { + return [kernel] (ArrayRef args) -> IValue { + if (args.size() != sizeof...(ParamTypes)) { + throw std::runtime_error("Wrong number of arguments for operator call"); + } + return return_type_to_ivalue( + (*kernel)(ivalue_to_arg_type>>>::call(args[indices])...) + ); + }; + } +}; +template struct _wrapKernel, FuncType> { + using parameter_types = guts::typelist::typelist; + + template + static KernelFunction call(FuncType* kernel, guts::index_sequence) { + return [kernel] (ArrayRef args) -> IValue { + if (args.size() != sizeof...(ParamTypes)) { + throw std::runtime_error("Wrong number of arguments for operator call"); + } + (*kernel)(ivalue_to_arg_type>>>::call(args[indices])...); + return IValue(); + }; + } +}; + +template +KernelFunction wrapKernel(typename SignatureTraits::func_type* kernel) { + using return_type = typename SignatureTraits::return_type; + using parameter_types = typename SignatureTraits::parameter_types; + using func_type = typename SignatureTraits::func_type; + constexpr size_t num_parameters = guts::typelist::size::value; + + return _wrapKernel::call( + kernel, + guts::make_index_sequence() + ); +} + /** * Wrapper class around a user-provided schema definition some useful information about the schema. * @@ -123,6 +207,10 @@ template class OpSignatureSchema final { static constexpr size_t num_outputs = OpSchemaDef::num_outputs(); + static KernelFunction wrap_kernel(func_type* kernel) { + return details::wrapKernel(kernel); + } + private: static_assert(details::has_parameter_names_defined::value, "Operator schema doesn't define parameter_names member."); // TODO Allow simpler definition of parameter_names without having to spell out the guts::array type in the schema def. 
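
The OpSchema.h machinery above is the heart of this patch: "boxing" means every kernel, whatever its concrete C++ signature, is stored behind one uniform type (KernelFunction, a std::function that takes an ArrayRef of IValues and returns an IValue), and wrapKernel generates the glue that unboxes each argument and re-boxes the result. The following is a minimal, self-contained sketch of that idea; the names Value, BoxedKernel, unbox, wrap_kernel and add_ints are invented for this illustration only, while the actual patch uses c10::IValue, KernelFunction, ivalue_to_arg_type and wrapKernel and additionally handles Tensor/ArrayRef arguments and void-returning kernels.

#include <cstdint>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <variant>
#include <vector>

// Stand-in for c10::IValue: a tagged value covering the argument types we need here.
using Value = std::variant<std::int64_t, double, bool>;

// Stand-in for c10::KernelFunction: one uniform "boxed" signature for every kernel.
using BoxedKernel = std::function<Value(const std::vector<Value>&)>;

// Convert one boxed argument back to the concrete type the kernel expects.
template <class T>
T unbox(const Value& v) {
  if (!std::holds_alternative<T>(v)) {
    throw std::runtime_error("argument has wrong type");
  }
  return std::get<T>(v);
}

namespace detail {
// Unbox each positional argument and forward it to the plain C++ kernel.
template <class Ret, class... Args, std::size_t... I>
Value call_unboxed(Ret (*kernel)(Args...),
                   const std::vector<Value>& args,
                   std::index_sequence<I...>) {
  return Value(kernel(unbox<Args>(args[I])...));
}
}  // namespace detail

// Wrap an unboxed kernel `Ret(Args...)` into the uniform boxed signature.
template <class Ret, class... Args>
BoxedKernel wrap_kernel(Ret (*kernel)(Args...)) {
  return [kernel](const std::vector<Value>& args) -> Value {
    if (args.size() != sizeof...(Args)) {
      throw std::runtime_error("wrong number of arguments");
    }
    return detail::call_unboxed(kernel, args, std::index_sequence_for<Args...>{});
  };
}

// An ordinary "unboxed" kernel with a plain C++ signature.
std::int64_t add_ints(std::int64_t a, std::int64_t b) { return a + b; }

int main() {
  BoxedKernel boxed = wrap_kernel(add_ints);
  Value result = boxed({Value(std::int64_t{3}), Value(std::int64_t{4})});
  std::cout << std::get<std::int64_t>(result) << "\n";  // prints 7
}

The payoff, visible in the DispatchTable.h hunks above, is that kernels with different signatures can all live in the same flat_hash_map and be invoked through a single uniform call; the cost is a per-call conversion between IValues and the kernel's native argument types.
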
@@ -169,16 +257,16 @@ class OpDispatchKeySchema; - template - static inline dispatch_key_type dispatch_key(const Args&... args) { + static inline dispatch_key_type dispatch_key(ArrayRef args) { + /* TODO Should we make this a runtime assert now? using guts::typelist::map_t; using guts::typelist::typelist; static_assert(std::is_same< map_t>>, map_t> - >::value, "Invalid argument types passed to OpSchema::dispatch_key()"); + >::value, "Invalid argument types passed to OpSchema::dispatch_key()");*/ return dispatch_key_type { - details::getDispatchTypeIds_(args...) + details::getDispatchTypeIds_(args) }; } }; @@ -201,21 +289,22 @@ class OpDispatchKeySchema::value, "Operator schema specified custom dispatch_key() derivation function, but the returned dispatch key type doesn't have an overload for std::hash. Please define it."); static_assert(std::is_same< - guts::typelist::map_t>, - guts::typelist::map_t> - >::value, "Operator schema defines custom dispatch_key() derivation function, but the arguments don't match the operator signature."); + guts::typelist::typelist>, + typename dispatch_key_traits::parameter_types + >::value, "Operator schema defines custom dispatch_key() derivation function, but it has the wrong signature. Expected to take one argument, which is of type ArrayRef."); public: - template - static inline dispatch_key_type dispatch_key(const Args&... args) { + static inline dispatch_key_type dispatch_key(ArrayRef args) { + /* TODO Should we make this a runtime assert now? using guts::typelist::map_t; using guts::typelist::typelist; static_assert(std::is_same< map_t>>, map_t> >::value, "Invalid argument types passed to OpSchema::dispatch_key()"); - return OpSchemaDef::dispatch_key(args...); + */ + return OpSchemaDef::dispatch_key(args); } }; diff --git a/aten/src/ATen/core/dispatch/OpSchema_test.cpp b/aten/src/ATen/core/dispatch/OpSchema_test.cpp index f03254722918d8..1f5f16a9f417fa 100644 --- a/aten/src/ATen/core/dispatch/OpSchema_test.cpp +++ b/aten/src/ATen/core/dispatch/OpSchema_test.cpp @@ -1,15 +1,17 @@ #include #include +#include using namespace c10; +using at::Tensor; -static_assert(details::is_tensor_arg::value, ""); -static_assert(details::is_tensor_arg::value, ""); -static_assert(details::is_tensor_arg::value, ""); +static_assert(details::is_tensor_arg::value, ""); +static_assert(details::is_tensor_arg::value, ""); +static_assert(details::is_tensor_arg::value, ""); static_assert(!details::is_tensor_arg::value, ""); struct SchemaDef final { - using Signature = bool(int, C10Tensor, float, C10Tensor, C10Tensor, unsigned int); + using Signature = bool(int, Tensor, float, Tensor, Tensor, unsigned int); static constexpr guts::array parameter_names = {{ "1", "2", "3", "4", "5", "6" }}; @@ -22,6 +24,6 @@ static_assert(std::is_same::signature::return static_assert( std::is_same< guts::typelist:: - typelist, + typelist, typename OpSchema::signature::parameter_types>::value, ""); diff --git a/aten/src/ATen/core/opschema/layer_norm.cpp b/aten/src/ATen/core/opschema/layer_norm.cpp index be908a58fc7fec..ca25c95ecf58f7 100644 --- a/aten/src/ATen/core/opschema/layer_norm.cpp +++ b/aten/src/ATen/core/opschema/layer_norm.cpp @@ -2,3 +2,7 @@ #include C10_DEFINE_OP_SCHEMA(c10::core::opschema::LayerNorm); + +namespace caffe2 { +CAFFE_KNOWN_TYPE(c10::core::opschema::LayerNorm::Cache); +} diff --git a/aten/src/ATen/core/opschema/layer_norm.h b/aten/src/ATen/core/opschema/layer_norm.h index d80c9650b32149..f0830b6cf69834 100644 --- a/aten/src/ATen/core/opschema/layer_norm.h +++ 
b/aten/src/ATen/core/opschema/layer_norm.h @@ -1,7 +1,8 @@ #pragma once -#include +#include #include +#include namespace c10 { namespace core { @@ -14,18 +15,18 @@ struct LayerNorm final { static constexpr const char* name = "LayerNorm"; struct Cache final { - at::optional scale = at::nullopt; - at::optional bias = at::nullopt; + at::optional scale = at::nullopt; + at::optional bias = at::nullopt; }; using Signature = void( - const C10Tensor& input, - const C10Tensor& output, - const C10Tensor& output_mean, - const C10Tensor& output_stddev, + const at::Tensor& input, + const at::Tensor& output, + const at::Tensor& output_mean, + const at::Tensor& output_stddev, int axis, float epsilon, - Cache* cache); + intrusive_ptr cache); static constexpr size_t num_dispatch_args() {return 1;} diff --git a/caffe2/core/operator_c10wrapper.h b/caffe2/core/operator_c10wrapper.h index 44591871697864..0e9e2e24c5fdef 100644 --- a/caffe2/core/operator_c10wrapper.h +++ b/caffe2/core/operator_c10wrapper.h @@ -4,20 +4,13 @@ #include "caffe2/core/operator.h" #include #include +#include namespace caffe2 { namespace details { template struct true_t : std::true_type {}; -template -inline std::shared_ptr init_state() { - return std::make_shared(); -} -template <> -inline std::shared_ptr init_state() { - return std::shared_ptr(); -} template using is_output_arg = std::is_same; template @@ -60,7 +53,7 @@ class C10OperatorWrapper final : public Operator { C10OperatorWrapper(const OperatorDef& operator_def, Workspace* ws) : Operator(operator_def, ws), - state_(details::init_state()), + state_(make_intrusive()), parameters_(parse_parameters_( operator_def, c10::guts::make_index_sequence())) {} @@ -115,11 +108,14 @@ class C10OperatorWrapper final : public Operator { c10::guts::index_sequence, c10::guts::index_sequence, c10::guts::index_sequence) { + state_->GetMutable(); // initialize state if not initialized yet c10::Dispatcher::call( - C10Tensor(Input(InputIndex))..., - C10Tensor(*Output(OutputIndex))..., - std::get(parameters_)..., - state_.get()); + ArrayRef{ + IValue(at::Tensor(C10Tensor(Input(InputIndex))))..., + IValue(at::Tensor(C10Tensor(*Output(OutputIndex))))..., + IValue(std::get(parameters_))..., + IValue(state_) + }); } template < @@ -135,9 +131,12 @@ class C10OperatorWrapper final : public Operator { c10::guts::index_sequence, c10::guts::index_sequence) { c10::Dispatcher::call( - C10Tensor(Input(InputIndex))..., - C10Tensor(*Output(OutputIndex))..., - std::get(parameters_)...); + // TODO Make outputs be returned, not passed in + ArrayRef{ + IValue(at::Tensor(C10Tensor(Input(InputIndex))))..., + IValue(at::Tensor(C10Tensor(*Output(OutputIndex))))..., + IValue(std::get(parameters_))... 
+ }); } template < @@ -152,11 +151,15 @@ class C10OperatorWrapper final : public Operator { c10::guts::index_sequence, c10::guts::index_sequence, c10::guts::index_sequence) { + state_->GetMutable(); // initialize state if not initialized yet c10::Dispatcher::call( - at::ArrayRef(array_inputs_()), - C10Tensor(*Output(OutputIndex))..., - std::get(parameters_)..., - state_.get()); + // TODO Make outputs be returned, not passed in + ArrayRef{ + IValue(at::ArrayRef(array_inputs_())), + IValue(at::Tensor(C10Tensor(*Output(OutputIndex))))..., + IValue(std::get(parameters_))..., + IValue(state_) + }); } template < @@ -172,21 +175,23 @@ class C10OperatorWrapper final : public Operator { c10::guts::index_sequence, c10::guts::index_sequence) { c10::Dispatcher::call( - at::ArrayRef(array_inputs_()), - C10Tensor(*Output(OutputIndex))..., - std::get(parameters_)...); + ArrayRef{ + IValue(ivalue::TensorList(array_inputs_())), + IValue(at::Tensor(C10Tensor(*Output(OutputIndex))))..., + IValue(std::get(parameters_))... + }); } - std::vector array_inputs_() { - std::vector result; + std::vector array_inputs_() { + std::vector result; result.reserve(InputSize()); for (size_t i = 0; i < InputSize(); ++i) { - result.push_back(C10Tensor(Input(i))); + result.push_back(at::Tensor(c10::C10Tensor(Input(i)))); } return result; } - std::shared_ptr state_; + intrusive_ptr state_; ParameterTuple parameters_; }; diff --git a/caffe2/operators/experimental/c10/cpu/add_cpu.cc b/caffe2/operators/experimental/c10/cpu/add_cpu.cc index b41cf0906f0ec2..dbebbb57e1e339 100644 --- a/caffe2/operators/experimental/c10/cpu/add_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/add_cpu.cc @@ -11,14 +11,14 @@ namespace { template void add_op_cpu_impl( - const C10Tensor& A_, - const C10Tensor& B_, - const C10Tensor& C_, + const at::Tensor& A_, + const at::Tensor& B_, + const at::Tensor& C_, bool legacy_broadcast, int axis) { - Tensor A(A_); - Tensor B(B_); - Tensor C(C_); + Tensor A{C10Tensor(A_)}; + Tensor B{C10Tensor(B_)}; + Tensor C{C10Tensor(C_)}; CPUContext context; const DataType* A_data = A.template data(); const DataType* B_data = B.template data(); diff --git a/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc b/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc index f661c5db12ea54..4223905bc4a799 100644 --- a/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc @@ -12,11 +12,12 @@ namespace { template void averaged_loss_op_cpu_impl( - const C10Tensor& X_, - const C10Tensor& sum_, - caffe2::ops::AveragedLoss::State* state) { - Tensor X(X_); - Tensor sum(sum_); + const at::Tensor& X_, + const at::Tensor& sum_, + intrusive_ptr state_) { + Tensor X{C10Tensor(X_)}; + Tensor sum{C10Tensor(sum_)}; + caffe2::ops::AveragedLoss::State* state = state_->GetMutable(); CPUContext context; sum.Resize(vector()); diff --git a/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc b/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc index 251b9a27c8458a..3786786f99cf0a 100644 --- a/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc @@ -12,12 +12,12 @@ namespace { template void batch_gather_op_cpu_impl( - const C10Tensor& data_, - const C10Tensor& indices_, - const C10Tensor& output_) { - Tensor data(data_); - Tensor indices(indices_); - Tensor output(output_); + const at::Tensor& data_, + const at::Tensor& indices_, + const at::Tensor& output_) { + Tensor 
data{C10Tensor(data_)}; + Tensor indices{C10Tensor(indices_)}; + Tensor output{C10Tensor(output_)}; CPUContext context; CAFFE_ENFORCE_GE(data.dim(), 2, "DATA should be at least 2-D"); diff --git a/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc b/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc index a12b3c6a815c0e..5fa5d52cdc1d63 100644 --- a/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc @@ -13,16 +13,17 @@ namespace { template void batch_matmul_op_cpu_impl( - const C10Tensor& A_, - const C10Tensor& B_, - const C10Tensor& Y_, + const at::Tensor& A_, + const at::Tensor& B_, + const at::Tensor& Y_, int trans_a, int trans_b, int broadcast, - caffe2::ops::BatchMatmul::State* state) { - Tensor A(A_); - Tensor B(B_); - Tensor Y(Y_); + intrusive_ptr state_) { + Tensor A{C10Tensor(A_)}; + Tensor B{C10Tensor(B_)}; + Tensor Y{C10Tensor(Y_)}; + caffe2::ops::BatchMatmul::State* state = state_->GetMutable(); CPUContext context; using Engine = caffe2::DefaultEngine; diff --git a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc index 35b1daa9645875..a236d539e0f2c0 100644 --- a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc @@ -23,11 +23,12 @@ void do_cast_(const Tensor& input, const Tensor& output) { template void cast_op_cpu_impl( - const C10Tensor& input_, - const C10Tensor& output_, - TensorProto_DataType to) { - Tensor input(input_); - Tensor output(output_); + const at::Tensor& input_, + const at::Tensor& output_, + int64_t to_) { + Tensor input{C10Tensor(input_)}; + Tensor output{C10Tensor(output_)}; + TensorProto_DataType to = static_cast(to_); switch (to) { case caffe2::TensorProto_DataType_FLOAT: diff --git a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc index a5089d448cce1c..d049ce036dac8e 100644 --- a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc @@ -13,13 +13,13 @@ namespace caffe2 { namespace { template void concat_op_cpu_impl( - at::ArrayRef inputs, - const C10Tensor& output_, - const C10Tensor& split_, + ArrayRef inputs, + const at::Tensor& output_, + const at::Tensor& split_, int axis, int add_axis) { - Tensor output(output_); - Tensor split(split_); + Tensor output{C10Tensor(output_)}; + Tensor split{C10Tensor(split_)}; CPUContext context; split.Resize(vector(1, inputs.size())); diff --git a/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc b/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc index 60f7b234872956..46df6035e686b4 100644 --- a/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc @@ -9,8 +9,8 @@ using caffe2::Tensor; namespace caffe2 { namespace { template -void enforce_finite_op_impl_cpu(const C10Tensor& input_) { - Tensor input(input_); +void enforce_finite_op_impl_cpu(const at::Tensor& input_) { + Tensor input{C10Tensor(input_)}; const DataType* input_data = input.template data(); auto size = input.numel(); diff --git a/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc b/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc index f4596c5ff01536..cc73d77d90dc3b 100644 --- a/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc @@ -9,15 +9,16 @@ namespace caffe2 { namespace { 
template void expand_dims_op_cpu_impl( - const C10Tensor& input_, - const C10Tensor& output_, - const std::vector& dims, - caffe2::ops::ExpandDims::State* state) { - Tensor input(input_); - Tensor output(output_); + const at::Tensor& input_, + const at::Tensor& output_, + ArrayRef dims, + intrusive_ptr state_) { + Tensor input{C10Tensor(input_)}; + Tensor output{C10Tensor(output_)}; + caffe2::ops::ExpandDims::State* state = state_->GetMutable(); if (!state->initialized) { - state->dims = dims; + state->dims = dims.vec(); auto originalSize = state->dims.size(); CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided."); std::sort(state->dims.begin(), state->dims.end()); diff --git a/caffe2/operators/experimental/c10/cpu/fc_cpu.cc b/caffe2/operators/experimental/c10/cpu/fc_cpu.cc index 9fbbfb73710f84..b22a9776095fdb 100644 --- a/caffe2/operators/experimental/c10/cpu/fc_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/fc_cpu.cc @@ -13,17 +13,18 @@ namespace caffe2 { namespace { template void fc_op_cpu_impl( - const C10Tensor& X_, - const C10Tensor& W_, - const C10Tensor& b_, - const C10Tensor& Y_, + const at::Tensor& X_, + const at::Tensor& W_, + const at::Tensor& b_, + const at::Tensor& Y_, int axis, int axis_w, - caffe2::ops::FullyConnected::Cache* cache) { - Tensor X(X_); - Tensor W(W_); - Tensor b(b_); - Tensor Y(Y_); + intrusive_ptr state_) { + Tensor X{C10Tensor(X_)}; + Tensor W{C10Tensor(W_)}; + Tensor b{C10Tensor(b_)}; + Tensor Y{C10Tensor(Y_)}; + caffe2::ops::FullyConnected::State* state = state_->GetMutable(); CPUContext context; constexpr bool TransposeWeight = true; @@ -62,12 +63,12 @@ void fc_op_cpu_impl( CAFFE_ENFORCE(N == b.dim32(0), dimErrorString()); CAFFE_ENFORCE(N == b.numel(), dimErrorString()); - cache->Y_shape_cache_ = X.sizes().vec(); + state->Y_shape_cache_ = X.sizes().vec(); // This is an invariant of canonical_axis, so we can DCHECK. 
- DCHECK_LE(canonical_axis + 1, cache->Y_shape_cache_.size()); - cache->Y_shape_cache_.resize(canonical_axis + 1); - cache->Y_shape_cache_[canonical_axis] = N; - Y.Resize(cache->Y_shape_cache_); + DCHECK_LE(canonical_axis + 1, state->Y_shape_cache_.size()); + state->Y_shape_cache_.resize(canonical_axis + 1); + state->Y_shape_cache_[canonical_axis] = N; + Y.Resize(state->Y_shape_cache_); CAFFE_ENFORCE(M * N == Y.numel(), dimErrorString()); if (X.numel() == 0) { @@ -97,7 +98,7 @@ void fc_op_cpu_impl( static_cast(&context), math_type); // Add bias term - Tensor bias_multiplier(cache->bias_multiplier_); + Tensor bias_multiplier(state->bias_multiplier_); ReinitializeTensor(&bias_multiplier, {M}, at::dtype().device(CPU)); caffe2::math::Set( M, diff --git a/caffe2/operators/experimental/c10/cpu/filler_cpu.cc b/caffe2/operators/experimental/c10/cpu/filler_cpu.cc index 161b6eb597fa2f..dc878b43911b6e 100644 --- a/caffe2/operators/experimental/c10/cpu/filler_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/filler_cpu.cc @@ -2,21 +2,23 @@ #include "caffe2/operators/experimental/c10/schemas/filler.h" #include "caffe2/utils/math.h" #include "caffe2/core/tensor.h" +#include using caffe2::CPUContext; using caffe2::Tensor; using caffe2::TensorCPU; using std::vector; +using c10::ivalue::TensorList; namespace caffe2 { namespace { void filler_init( - at::ArrayRef inputs, - const C10Tensor& output_, - const std::vector& shape, - const std::vector& extra_shape, + ArrayRef inputs, + const at::Tensor& output_, + ArrayRef shape, + ArrayRef extra_shape, bool input_as_shape) { - Tensor output(output_); + Tensor output{C10Tensor(output_)}; if (inputs.size()) { auto real_shape = vector{}; if (input_as_shape) { @@ -44,14 +46,14 @@ void filler_init( template void given_tensor_fill_op_cpu_impl( - at::ArrayRef inputs, - const C10Tensor& output_, - const std::vector& shape, - const std::vector& extra_shape, + ArrayRef inputs, + const at::Tensor& output_, + ArrayRef shape, + ArrayRef extra_shape, bool input_as_shape, - const C10Tensor& values_) { - Tensor output(output_); - Tensor values(values_); + const at::Tensor& values_) { + Tensor output{C10Tensor(output_)}; + Tensor values{C10Tensor(values_)}; CPUContext context; filler_init(inputs, output_, shape, extra_shape, input_as_shape); @@ -69,14 +71,14 @@ void given_tensor_fill_op_cpu_impl( } void constant_fill_op_cpu_impl( - at::ArrayRef inputs, - const C10Tensor& output_, - const std::vector& shape, - const std::vector& extra_shape, + ArrayRef inputs, + const at::Tensor& output_, + ArrayRef shape, + ArrayRef extra_shape, bool input_as_shape, int dtype, - caffe2::ops::ConstantFill::Value value) { - Tensor output(output_); + c10::IValue value) { + Tensor output{C10Tensor(output_)}; CPUContext context; filler_init(inputs, output_, shape, extra_shape, input_as_shape); @@ -85,25 +87,25 @@ void constant_fill_op_cpu_impl( if (dtype == caffe2::TensorProto_DataType_FLOAT) { caffe2::math::Set( output.numel(), - value.as_float, + value.toDouble(), output.template mutable_data(), static_cast(&context)); } else if (dtype == caffe2::TensorProto_DataType_INT32) { caffe2::math::Set( output.numel(), - value.as_int32, + value.toInt(), output.template mutable_data(), static_cast(&context)); } else if (dtype == caffe2::TensorProto_DataType_INT64) { caffe2::math::Set( output.numel(), - value.as_int64, + value.toInt(), output.template mutable_data(), static_cast(&context)); } else if (dtype == caffe2::TensorProto_DataType_BOOL) { caffe2::math::Set( output.numel(), - value.as_bool, + 
value.toBool(), output.template mutable_data(), static_cast(&context)); } else { @@ -115,14 +117,14 @@ void constant_fill_op_cpu_impl( } void uniform_fill_op_cpu_impl( - at::ArrayRef inputs, - const C10Tensor& output_, - const std::vector& shape, - const std::vector& extra_shape, + ArrayRef inputs, + const at::Tensor& output_, + ArrayRef shape, + ArrayRef extra_shape, bool input_as_shape, float min, float max) { - Tensor output(output_); + Tensor output{C10Tensor(output_)}; CPUContext context; filler_init(inputs, output_, shape, extra_shape, input_as_shape); diff --git a/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc b/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc index 09099e0e45257e..a8eb14a28a07f9 100644 --- a/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc @@ -10,11 +10,11 @@ namespace caffe2 { namespace { template void flatten_op_cpu_impl( - const C10Tensor& input_, - const C10Tensor& output_, + const at::Tensor& input_, + const at::Tensor& output_, int axis) { - Tensor input(input_); - Tensor output(output_); + Tensor input{C10Tensor(input_)}; + Tensor output{C10Tensor(output_)}; CPUContext context; CAFFE_ENFORCE_GE( input.sizes().size(), axis, "The rank of the tensor must be >= axis."); diff --git a/caffe2/operators/experimental/c10/cpu/mul_cpu.cc b/caffe2/operators/experimental/c10/cpu/mul_cpu.cc index 2be687fa950daa..067eefc8f23769 100644 --- a/caffe2/operators/experimental/c10/cpu/mul_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/mul_cpu.cc @@ -12,14 +12,14 @@ namespace { template void mul_op_cpu_impl( - const C10Tensor& A_, - const C10Tensor& B_, - const C10Tensor& C_, + const at::Tensor& A_, + const at::Tensor& B_, + const at::Tensor& C_, bool legacy_broadcast, int axis) { - Tensor A(A_); - Tensor B(B_); - Tensor C(C_); + Tensor A{C10Tensor(A_)}; + Tensor B{C10Tensor(B_)}; + Tensor C{C10Tensor(C_)}; CPUContext context; const DataType* A_data = A.template data(); const DataType* B_data = B.template data(); diff --git a/caffe2/operators/experimental/c10/cpu/relu_cpu.cc b/caffe2/operators/experimental/c10/cpu/relu_cpu.cc index a9971dea0f6335..ca66403245f99e 100644 --- a/caffe2/operators/experimental/c10/cpu/relu_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/relu_cpu.cc @@ -10,10 +10,10 @@ namespace caffe2 { namespace { template void relu_op_cpu_impl( - const C10Tensor& input_, - const C10Tensor& output_) { - Tensor input(input_); - Tensor output(output_); + const at::Tensor& input_, + const at::Tensor& output_) { + Tensor input{C10Tensor(input_)}; + Tensor output{C10Tensor(output_)}; output.ResizeLike(input); diff --git a/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc b/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc index 2f81947fba2585..13bf7d25167e72 100644 --- a/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc @@ -10,10 +10,10 @@ namespace caffe2 { namespace { template void sigmoid_op_cpu_impl( - const C10Tensor& input_, - const C10Tensor& output_) { - Tensor input(input_); - Tensor output(output_); + const at::Tensor& input_, + const at::Tensor& output_) { + Tensor input{C10Tensor(input_)}; + Tensor output{C10Tensor(output_)}; output.ResizeLike(input); caffe2::ConstEigenVectorArrayMap xM( diff --git a/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc b/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc index bc7626bba034e9..6c21ae1b7f5411 100644 --- 
a/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc @@ -26,14 +26,14 @@ inline float unjoined_sigmoid_xent_forward(float lgt, float tgt) { } void sigmoid_cross_entropy_with_logits_op_cpu_impl( - const C10Tensor& logits_, - const C10Tensor& targets_, - const C10Tensor& out_, + const at::Tensor& logits_, + const at::Tensor& targets_, + const at::Tensor& out_, bool log_D_trick, bool unjoined_lr_loss) { - Tensor logits(logits_); - Tensor targets(targets_); - Tensor out(out_); + Tensor logits{C10Tensor(logits_)}; + Tensor targets{C10Tensor(targets_)}; + Tensor out{C10Tensor(out_)}; CAFFE_ENFORCE_EQ(logits.sizes(), targets.sizes()); const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1; diff --git a/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc b/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc index a9f0f303c36890..a58f762ea41547 100644 --- a/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc @@ -11,14 +11,14 @@ namespace { template void sparse_lengths_sum_op_cpu_impl( - const C10Tensor& dataInput_, - const C10Tensor& indicesInput_, - const C10Tensor& lengthsInput_, - const C10Tensor& output_) { - Tensor dataInput(dataInput_); - Tensor indicesInput(indicesInput_); - Tensor lengthsInput(lengthsInput_); - Tensor output(output_); + const at::Tensor& dataInput_, + const at::Tensor& indicesInput_, + const at::Tensor& lengthsInput_, + const at::Tensor& output_) { + Tensor dataInput{C10Tensor(dataInput_)}; + Tensor indicesInput{C10Tensor(indicesInput_)}; + Tensor lengthsInput{C10Tensor(lengthsInput_)}; + Tensor output{C10Tensor(output_)}; using T = float; constexpr bool USE_MEAN = false; diff --git a/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc b/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc index 4c0cc8bd9dfb35..d77e4304968df0 100644 --- a/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc @@ -10,10 +10,10 @@ namespace caffe2 { namespace { template void stop_gradient_op_cpu_impl( - const C10Tensor& input_, - const C10Tensor& output_) { - Tensor input(input_); - Tensor output(output_); + const at::Tensor& input_, + const at::Tensor& output_) { + Tensor input{C10Tensor(input_)}; + Tensor output{C10Tensor(output_)}; if (!output.is_same(input)) { output.CopyFrom(input); } diff --git a/caffe2/operators/experimental/c10/schemas/add.h b/caffe2/operators/experimental/c10/schemas/add.h index 75c4a979eba772..fba907ab3160a2 100644 --- a/caffe2/operators/experimental/c10/schemas/add.h +++ b/caffe2/operators/experimental/c10/schemas/add.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include "caffe2/core/context_base.h" @@ -11,9 +11,9 @@ struct Add final { static constexpr const char* name = "add"; using Signature = void( - const C10Tensor& input1, - const C10Tensor& input2, - const C10Tensor& output, + const at::Tensor& input1, + const at::Tensor& input2, + const at::Tensor& output, bool legacy_broadcast, int axis); diff --git a/caffe2/operators/experimental/c10/schemas/averaged_loss.cc b/caffe2/operators/experimental/c10/schemas/averaged_loss.cc index ef5bd712be8d5b..1ba38a795a674c 100644 --- a/caffe2/operators/experimental/c10/schemas/averaged_loss.cc +++ b/caffe2/operators/experimental/c10/schemas/averaged_loss.cc @@ -7,6 +7,9 @@ using caffe2::CPUContext; 
C10_DEFINE_OP_SCHEMA(caffe2::ops::AveragedLoss); namespace caffe2 { + +CAFFE_KNOWN_TYPE(ops::AveragedLoss::State); + REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::AveragedLoss, ops::AveragedLoss::State, diff --git a/caffe2/operators/experimental/c10/schemas/averaged_loss.h b/caffe2/operators/experimental/c10/schemas/averaged_loss.h index 8000181e6c57f1..5193a6835d567a 100644 --- a/caffe2/operators/experimental/c10/schemas/averaged_loss.h +++ b/caffe2/operators/experimental/c10/schemas/averaged_loss.h @@ -1,24 +1,25 @@ #pragma once -#include +#include #include #include "caffe2/core/context_base.h" #include "caffe2/core/tensor.h" +#include namespace caffe2 { namespace ops { struct AveragedLoss final { struct State final { - C10Tensor scratch = C10Tensor(empty({}, CPU)); + at::Tensor scratch = at::Tensor(C10Tensor(empty({}, CPU))); }; static constexpr const char* name = "averaged_loss"; using Signature = void( - const C10Tensor& input, - const C10Tensor& output, - State* state); + const at::Tensor& input, + const at::Tensor& output, + intrusive_ptr state); static constexpr size_t num_dispatch_args() {return 1;} diff --git a/caffe2/operators/experimental/c10/schemas/batch_gather.h b/caffe2/operators/experimental/c10/schemas/batch_gather.h index fc4f5ccd2934a2..d745efa35f1969 100644 --- a/caffe2/operators/experimental/c10/schemas/batch_gather.h +++ b/caffe2/operators/experimental/c10/schemas/batch_gather.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include "caffe2/core/context_base.h" @@ -11,9 +11,9 @@ struct BatchGather final { static constexpr const char* name = "batch_gather"; using Signature = void( - const C10Tensor& data, - const C10Tensor& indices, - const C10Tensor& output); + const at::Tensor& data, + const at::Tensor& indices, + const at::Tensor& output); static constexpr size_t num_dispatch_args() {return 2;} diff --git a/caffe2/operators/experimental/c10/schemas/batch_matmul.cc b/caffe2/operators/experimental/c10/schemas/batch_matmul.cc index 5e351175749139..49ea333f25d29b 100644 --- a/caffe2/operators/experimental/c10/schemas/batch_matmul.cc +++ b/caffe2/operators/experimental/c10/schemas/batch_matmul.cc @@ -37,6 +37,9 @@ struct BroadcastParameter final { } // namespace namespace caffe2 { + +CAFFE_KNOWN_TYPE(ops::BatchMatmul::State); + REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::BatchMatmul, ops::BatchMatmul::State, diff --git a/caffe2/operators/experimental/c10/schemas/batch_matmul.h b/caffe2/operators/experimental/c10/schemas/batch_matmul.h index 90d8b16ee10013..7827e97fdc7d52 100644 --- a/caffe2/operators/experimental/c10/schemas/batch_matmul.h +++ b/caffe2/operators/experimental/c10/schemas/batch_matmul.h @@ -1,27 +1,28 @@ #pragma once -#include +#include #include #include "caffe2/core/context_base.h" +#include namespace caffe2 { namespace ops { struct BatchMatmul final { struct State final { - std::shared_ptr scratch; + std::shared_ptr scratch; }; static constexpr const char* name = "batch_matmul"; using Signature = void( - const C10Tensor& A, - const C10Tensor& B, - const C10Tensor& output, + const at::Tensor& A, + const at::Tensor& B, + const at::Tensor& output, int trans_a, int trans_b, int broadcast, - State* state); + intrusive_ptr state); static constexpr size_t num_dispatch_args() {return 2;} diff --git a/caffe2/operators/experimental/c10/schemas/cast.h b/caffe2/operators/experimental/c10/schemas/cast.h index 1e8204bc8e7867..095348b76807b1 100644 --- a/caffe2/operators/experimental/c10/schemas/cast.h +++ 
b/caffe2/operators/experimental/c10/schemas/cast.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include "caffe2/core/context_base.h" @@ -11,9 +11,9 @@ struct Cast final { static constexpr const char* name = "cast"; using Signature = void( - const C10Tensor& input1, - const C10Tensor& output, - TensorProto_DataType to); + const at::Tensor& input1, + const at::Tensor& output, + int64_t to_dtype); static constexpr size_t num_dispatch_args() {return 1;} diff --git a/caffe2/operators/experimental/c10/schemas/concat.h b/caffe2/operators/experimental/c10/schemas/concat.h index 142cf934a36348..9a060bcaa69fa5 100644 --- a/caffe2/operators/experimental/c10/schemas/concat.h +++ b/caffe2/operators/experimental/c10/schemas/concat.h @@ -1,10 +1,11 @@ #pragma once #include -#include +#include #include #include #include "caffe2/core/context_base.h" +#include namespace caffe2 { namespace ops { @@ -13,9 +14,9 @@ struct Concat final { static constexpr const char* name = "concat"; using Signature = void( - at::ArrayRef inputs, - const C10Tensor& output, - const C10Tensor& split_info, + ArrayRef inputs, + const at::Tensor& output, + const at::Tensor& split_info, int add, int add_axis); @@ -25,11 +26,7 @@ struct Concat final { {"inputs", "output", "split_info_output", "add", "add_axis"}}; static c10::DeviceTypeId dispatch_key( - at::ArrayRef inputs, - const C10Tensor& output, - const C10Tensor& split_info, - int add, - int add_axis) { + at::ArrayRef arguments) { return c10::DeviceTypeId::CPU; } }; diff --git a/caffe2/operators/experimental/c10/schemas/enforce_finite.h b/caffe2/operators/experimental/c10/schemas/enforce_finite.h index 2e3f0dabbd7fee..f811e2b6d88d33 100644 --- a/caffe2/operators/experimental/c10/schemas/enforce_finite.h +++ b/caffe2/operators/experimental/c10/schemas/enforce_finite.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace caffe2 { @@ -9,7 +9,7 @@ namespace ops { struct EnforceFinite final { static constexpr const char* name = "enforce_finite"; - using Signature = void(const C10Tensor& input); + using Signature = void(const at::Tensor& input); static constexpr size_t num_dispatch_args() {return 1;} diff --git a/caffe2/operators/experimental/c10/schemas/expand_dims.cc b/caffe2/operators/experimental/c10/schemas/expand_dims.cc index 82c261666c79ae..df1f42771b4d38 100644 --- a/caffe2/operators/experimental/c10/schemas/expand_dims.cc +++ b/caffe2/operators/experimental/c10/schemas/expand_dims.cc @@ -3,19 +3,24 @@ #include "caffe2/core/operator_c10wrapper.h" using caffe2::CPUContext; +using c10::intrusive_ptr; +using c10::ivalue::IntList; C10_DEFINE_OP_SCHEMA(caffe2::ops::ExpandDims); namespace { struct DimsParameter final { - using type = std::vector; - static std::vector parse(const caffe2::ArgumentHelper& helper) { - return helper.GetRepeatedArgument("dims"); + using type = intrusive_ptr; + static intrusive_ptr parse(const caffe2::ArgumentHelper& helper) { + return IntList::create(helper.GetRepeatedArgument("dims")); } }; } // namespace namespace caffe2 { + +CAFFE_KNOWN_TYPE(ops::ExpandDims::State); + REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::ExpandDims, ops::ExpandDims::State, diff --git a/caffe2/operators/experimental/c10/schemas/expand_dims.h b/caffe2/operators/experimental/c10/schemas/expand_dims.h index a4721892c66e7e..53cb7fc60a7af9 100644 --- a/caffe2/operators/experimental/c10/schemas/expand_dims.h +++ b/caffe2/operators/experimental/c10/schemas/expand_dims.h @@ -1,25 +1,27 @@ #pragma once -#include +#include #include #include 
"caffe2/core/context_base.h" +#include +#include namespace caffe2 { namespace ops { struct ExpandDims final { struct State { - std::vector dims; + std::vector dims; bool initialized = false; }; static constexpr const char* name = "expand_dims"; using Signature = void( - const C10Tensor& input, - const C10Tensor& output, - const std::vector& dims, - State* state); + const at::Tensor& input, + const at::Tensor& output, + ArrayRef dims, + intrusive_ptr state); static constexpr size_t num_dispatch_args() {return 1;} diff --git a/caffe2/operators/experimental/c10/schemas/fc.cc b/caffe2/operators/experimental/c10/schemas/fc.cc index a081c735b791a2..83092d27373456 100644 --- a/caffe2/operators/experimental/c10/schemas/fc.cc +++ b/caffe2/operators/experimental/c10/schemas/fc.cc @@ -28,9 +28,12 @@ struct AxisWParameter final { } // namespace namespace caffe2 { + +CAFFE_KNOWN_TYPE(ops::FullyConnected::State); + REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS( ops::FullyConnected, - ops::FullyConnected::Cache, + ops::FullyConnected::State, C10FC_DontUseThisOpYet, ParameterHelper, ParameterHelper) diff --git a/caffe2/operators/experimental/c10/schemas/fc.h b/caffe2/operators/experimental/c10/schemas/fc.h index 8730c3a5ec78f0..57d15bd81d8407 100644 --- a/caffe2/operators/experimental/c10/schemas/fc.h +++ b/caffe2/operators/experimental/c10/schemas/fc.h @@ -1,8 +1,9 @@ #pragma once -#include +#include #include #include "caffe2/core/tensor.h" +#include namespace caffe2 { namespace ops { @@ -10,19 +11,19 @@ namespace ops { struct FullyConnected final { static constexpr const char* name = "FC"; - struct Cache final { + struct State final { vector Y_shape_cache_; - C10Tensor bias_multiplier_ = C10Tensor(Tensor()); + at::Tensor bias_multiplier_ = at::Tensor(C10Tensor(Tensor())); }; using Signature = void( - const C10Tensor& X, - const C10Tensor& W, - const C10Tensor& b, - const C10Tensor& output, + const at::Tensor& X, + const at::Tensor& W, + const at::Tensor& b, + const at::Tensor& output, int axis, int axis_w, - Cache* cache); + intrusive_ptr state); static constexpr size_t num_dispatch_args() {return 3;} diff --git a/caffe2/operators/experimental/c10/schemas/filler.cc b/caffe2/operators/experimental/c10/schemas/filler.cc index e0a0e596304fec..c104c628c320aa 100644 --- a/caffe2/operators/experimental/c10/schemas/filler.cc +++ b/caffe2/operators/experimental/c10/schemas/filler.cc @@ -5,6 +5,8 @@ using caffe2::CPUContext; using c10::C10Tensor; +using c10::ivalue::IntList; +using c10::intrusive_ptr; C10_DEFINE_OP_SCHEMA(caffe2::ops::ConstantFill); C10_DEFINE_OP_SCHEMA(caffe2::ops::UniformFill); @@ -15,15 +17,15 @@ C10_DEFINE_OP_SCHEMA(caffe2::ops::GivenTensorFill); namespace { struct ShapeParameter final { - using type = std::vector; - static std::vector parse(const caffe2::ArgumentHelper& helper) { - return helper.GetRepeatedArgument("shape"); + using type = intrusive_ptr; + static intrusive_ptr parse(const caffe2::ArgumentHelper& helper) { + return IntList::create(helper.GetRepeatedArgument("shape")); } }; struct ExtraShapeParameter final { - using type = std::vector; - static std::vector parse(const caffe2::ArgumentHelper& helper) { - return helper.GetRepeatedArgument("extra_shape"); + using type = intrusive_ptr; + static intrusive_ptr parse(const caffe2::ArgumentHelper& helper) { + return IntList::create(helper.GetRepeatedArgument("extra_shape")); } }; struct InputAsShapeParameter final { @@ -54,20 +56,20 @@ struct DTypeParameter final { } }; struct ValueParameter final { - using type = 
caffe2::ops::ConstantFill::Value; - static caffe2::ops::ConstantFill::Value parse( + using type = c10::IValue; + static c10::IValue parse( const caffe2::ArgumentHelper& helper) { - caffe2::ops::ConstantFill::Value result; + c10::IValue result; if (helper.HasSingleArgumentOfType("value")) { - result.as_float = helper.GetSingleArgument("value", 0); + result = helper.GetSingleArgument("value", 0); } else if (helper.HasSingleArgumentOfType("value")) { - result.as_int32 = helper.GetSingleArgument("value", 0); + result = helper.GetSingleArgument("value", 0); } else if (helper.HasSingleArgumentOfType("value")) { - result.as_int64 = helper.GetSingleArgument("value", 0); + result = helper.GetSingleArgument("value", 0); } else if (helper.HasSingleArgumentOfType("value")) { - result.as_bool = helper.GetSingleArgument("value", false); + result = helper.GetSingleArgument("value", false); } else { - result.as_float = 0.0; + result = 0.0; } return result; } @@ -86,8 +88,8 @@ struct MaxParameter final { }; template struct ValuesParameter final { - using type = C10Tensor; - static C10Tensor parse(const caffe2::ArgumentHelper& helper) { + using type = at::Tensor; + static at::Tensor parse(const caffe2::ArgumentHelper& helper) { if (!std::is_same::value || !helper.HasArgument("dtype")) { return ExtractValues(helper); } else { @@ -115,7 +117,7 @@ struct ValuesParameter final { private: template - static C10Tensor ExtractValues( + static at::Tensor ExtractValues( const caffe2::ArgumentHelper& helper) { auto source_values = helper.GetRepeatedArgument("values"); caffe2::Tensor values{caffe2::CPU}; @@ -125,7 +127,7 @@ struct ValuesParameter final { values_data[i] = static_cast(source_values[i]); } // body_ = &GivenTensorFillOp::FillWithType; - return C10Tensor(values); + return at::Tensor(C10Tensor(values)); } }; } // namespace diff --git a/caffe2/operators/experimental/c10/schemas/filler.h b/caffe2/operators/experimental/c10/schemas/filler.h index 4a843d978e3d52..359d6abaa723a6 100644 --- a/caffe2/operators/experimental/c10/schemas/filler.h +++ b/caffe2/operators/experimental/c10/schemas/filler.h @@ -1,7 +1,8 @@ #pragma once #include -#include +#include +#include #include #include #include "caffe2/core/context_base.h" @@ -18,12 +19,12 @@ struct GivenTensorFill final { static constexpr const char* name = "given_tensor_fill"; using Signature = void( - at::ArrayRef inputs, - const C10Tensor& output, - const std::vector& shape, - const std::vector& extra_shape, + ArrayRef inputs, + const at::Tensor& output, + ArrayRef shape, + ArrayRef extra_shape, bool input_as_shape, - const C10Tensor& values); + const at::Tensor& values); static constexpr c10::guts::array parameter_names = { {"inputs", @@ -36,33 +37,22 @@ struct GivenTensorFill final { static constexpr size_t num_outputs() {return 1;} static c10::DeviceTypeId dispatch_key( - at::ArrayRef inputs, - const C10Tensor& output, - const std::vector& shape, - const std::vector& extra_shape, - bool input_as_shape, - const C10Tensor& values) { + c10::ArrayRef args) { return c10::DeviceTypeId::CPU; } }; struct ConstantFill final { - union Value { - float as_float; - int32_t as_int32; - int64_t as_int64; - bool as_bool; - }; static constexpr const char* name = "constant_fill"; using Signature = void( - at::ArrayRef inputs, - const C10Tensor& output, - const std::vector& shape, - const std::vector& extra_shape, + ArrayRef inputs, + const at::Tensor& output, + ArrayRef shape, + ArrayRef extra_shape, bool input_as_shape, int dtype, - Value value); + IValue value); static 
constexpr size_t num_outputs() {return 1;} @@ -76,13 +66,7 @@ struct ConstantFill final { "value"}}; static c10::DeviceTypeId dispatch_key( - at::ArrayRef inputs, - const C10Tensor& output, - const std::vector& shape, - const std::vector& extra_shape, - bool input_as_shape, - int dtype, - Value value) { + c10::ArrayRef args) { return c10::DeviceTypeId::CPU; } }; @@ -91,10 +75,10 @@ struct UniformFill final { static constexpr const char* name = "uniform_fill"; using Signature = void( - at::ArrayRef inputs, - const C10Tensor& output, - const std::vector& shape, - const std::vector& extra_shape, + ArrayRef inputs, + const at::Tensor& output, + ArrayRef shape, + ArrayRef extra_shape, bool input_as_shape, float min, float max); @@ -111,13 +95,7 @@ struct UniformFill final { "max"}}; static c10::DeviceTypeId dispatch_key( - at::ArrayRef inputs, - const C10Tensor& output, - const std::vector& shape, - const std::vector& extra_shape, - bool input_as_shape, - float min, - float max) { + c10::ArrayRef args) { return c10::DeviceTypeId::CPU; } }; diff --git a/caffe2/operators/experimental/c10/schemas/flatten.h b/caffe2/operators/experimental/c10/schemas/flatten.h index 31622d6f421c14..0ee1773cb6af6f 100644 --- a/caffe2/operators/experimental/c10/schemas/flatten.h +++ b/caffe2/operators/experimental/c10/schemas/flatten.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include "caffe2/core/context_base.h" @@ -11,8 +11,8 @@ struct Flatten final { static constexpr const char* name = "flatten"; using Signature = void( - const C10Tensor& input, - const C10Tensor& output, + const at::Tensor& input, + const at::Tensor& output, int axis); static constexpr size_t num_dispatch_args() {return 1;} diff --git a/caffe2/operators/experimental/c10/schemas/mul.h b/caffe2/operators/experimental/c10/schemas/mul.h index 6d7bdffd269407..12a178033031f1 100644 --- a/caffe2/operators/experimental/c10/schemas/mul.h +++ b/caffe2/operators/experimental/c10/schemas/mul.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include "caffe2/core/context_base.h" @@ -11,9 +11,9 @@ struct Mul final { static constexpr const char* name = "mul"; using Signature = void( - const C10Tensor& input1, - const C10Tensor& input2, - const C10Tensor& output, + const at::Tensor& input1, + const at::Tensor& input2, + const at::Tensor& output, bool legacy_broadcast, int axis); diff --git a/caffe2/operators/experimental/c10/schemas/relu.h b/caffe2/operators/experimental/c10/schemas/relu.h index 19606f8f8afb3a..bf1b8fd03cc6c8 100644 --- a/caffe2/operators/experimental/c10/schemas/relu.h +++ b/caffe2/operators/experimental/c10/schemas/relu.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace caffe2 { @@ -10,7 +10,7 @@ struct Relu final { static constexpr const char* name = "relu"; using Signature = - void(const C10Tensor& input, const C10Tensor& output); + void(const at::Tensor& input, const at::Tensor& output); static constexpr size_t num_dispatch_args() {return 1;} diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid.h b/caffe2/operators/experimental/c10/schemas/sigmoid.h index ad70d05c19fa9d..326dc078f4a2d8 100644 --- a/caffe2/operators/experimental/c10/schemas/sigmoid.h +++ b/caffe2/operators/experimental/c10/schemas/sigmoid.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace caffe2 { @@ -10,7 +10,7 @@ struct Sigmoid final { static constexpr const char* name = "sigmoid"; using Signature = - void(const C10Tensor& input, const C10Tensor& output); + void(const at::Tensor& input, const 
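These schemas also switch dispatch_key from the full typed parameter list to the same boxed ArrayRef<IValue> the kernel receives. The experimental CPU ops simply return DeviceTypeId::CPU unconditionally; a schema that actually dispatched on device would look roughly like the sketch below. The op name, signature, and CUDA branch are hypothetical, and DeviceTypeId is assumed to come from the dispatcher headers these schemas already include:

// Sketch only, not a schema defined in this patch.
#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>
#include <c10/util/ArrayRef.h>

struct MyOp final {
  static constexpr const char* name = "my_op";

  using Signature = void(const at::Tensor& input, const at::Tensor& output, int axis);

  // c10::DeviceTypeId comes from the c10 dispatch headers (exact path not shown here).
  static c10::DeviceTypeId dispatch_key(c10::ArrayRef<c10::IValue> args) {
    // With boxed calling there is no typed parameter list to inspect any more;
    // the schema peeks into the IValue array itself if it cares about the device.
    if (!args.empty() && args[0].isTensor() && args[0].toTensor().is_cuda()) {
      return c10::DeviceTypeId::CUDA;
    }
    return c10::DeviceTypeId::CPU;
  }
};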
at::Tensor& output); static constexpr size_t num_dispatch_args() {return 1;} diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h index 7e2d8d727d8f6f..7fb7a88ece1fc4 100644 --- a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h +++ b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace caffe2 { @@ -10,9 +10,9 @@ struct SigmoidCrossEntropyWithLogits final { static constexpr const char* name = "sigmoid_cross_entropy_with_logits"; using Signature = void( - const C10Tensor& input1, - const C10Tensor& input2, - const C10Tensor& output, + const at::Tensor& input1, + const at::Tensor& input2, + const at::Tensor& output, bool log_D_trick, bool unjoined_lr_loss); diff --git a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h index 33f96553abbdbf..16d23e733e0a34 100644 --- a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h +++ b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace caffe2 { @@ -10,10 +10,10 @@ struct SparseLengthsSum final { static constexpr const char* name = "sparse_lengths_sum"; using Signature = void( - const C10Tensor& data, - const C10Tensor& indices, - const C10Tensor& lengths, - const C10Tensor& output); + const at::Tensor& data, + const at::Tensor& indices, + const at::Tensor& lengths, + const at::Tensor& output); static constexpr size_t num_dispatch_args() {return 3;} diff --git a/caffe2/operators/experimental/c10/schemas/stop_gradient.h b/caffe2/operators/experimental/c10/schemas/stop_gradient.h index 7c17765b2adc46..f89f942eb706a6 100644 --- a/caffe2/operators/experimental/c10/schemas/stop_gradient.h +++ b/caffe2/operators/experimental/c10/schemas/stop_gradient.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include "caffe2/core/context_base.h" @@ -11,8 +11,8 @@ struct StopGradient final { static constexpr const char* name = "stop_gradient"; using Signature = void( - const C10Tensor& input, - const C10Tensor& output); + const at::Tensor& input, + const at::Tensor& output); static constexpr size_t num_dispatch_args() {return 1;} diff --git a/caffe2/operators/layer_norm_op.cc b/caffe2/operators/layer_norm_op.cc index d02102248090d6..ca39be5e5e847d 100644 --- a/caffe2/operators/layer_norm_op.cc +++ b/caffe2/operators/layer_norm_op.cc @@ -2,6 +2,7 @@ #include "caffe2/utils/eigen_utils.h" #include #include +#include namespace caffe2 { @@ -187,23 +188,24 @@ to the end.) 
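The layer_norm_c10 kernel in the hunk below unwraps its cache from a caffe2::Blob held by intrusive_ptr, which is the same pattern the ExpandDims and FullyConnected schema changes above call for. As a minimal standalone sketch, with a made-up kernel and state type, and include paths that only approximate the tree at this point:

// Sketch only, not code from this patch.
#include <ATen/core/Tensor.h>
#include <ATen/core/blob.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/intrusive_ptr.h>
#include <cstdint>
#include <vector>

namespace caffe2 {

struct MyKernelState {            // hypothetical; stands in for ExpandDims::State etc.
  std::vector<int64_t> dims;
  bool initialized = false;
};

// The state type has to be known to caffe2's type system before a Blob can hold it,
// mirroring the CAFFE_KNOWN_TYPE(ops::FullyConnected::State) added in fc.cc above.
CAFFE_KNOWN_TYPE(MyKernelState);

void my_kernel_cpu(
    const at::Tensor& input,
    const at::Tensor& output,
    c10::ArrayRef<int64_t> dims,
    c10::intrusive_ptr<Blob> state_blob) {
  // GetMutable<T>() default-constructs the state on first use and hands back the
  // same instance on later calls, replacing the old raw State* / Cache* parameter.
  MyKernelState* state = state_blob->GetMutable<MyKernelState>();
  if (!state->initialized) {
    state->dims.assign(dims.begin(), dims.end());
    state->initialized = true;
  }
  // ... the actual computation on input/output would go here ...
  (void)input;
  (void)output;
}

} // namespace caffe2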
 namespace {
 template
 void layer_norm_c10(
-    const c10::C10Tensor& X_,
-    const c10::C10Tensor& Y_,
-    const c10::C10Tensor& mean_,
-    const c10::C10Tensor& sig_,
+    const at::Tensor& X_,
+    const at::Tensor& Y_,
+    const at::Tensor& mean_,
+    const at::Tensor& sig_,
     int axis,
     float epsilon,
-    c10::core::opschema::LayerNorm::Cache* cache) {
-  caffe2::Tensor X(X_);
-  caffe2::Tensor Y(Y_);
-  caffe2::Tensor mean(mean_);
-  caffe2::Tensor sig(sig_);
+    c10::intrusive_ptr<caffe2::Blob> cache_) {
+  caffe2::Tensor X{c10::C10Tensor(X_)};
+  caffe2::Tensor Y{c10::C10Tensor(Y_)};
+  caffe2::Tensor mean{c10::C10Tensor(mean_)};
+  caffe2::Tensor sig{c10::C10Tensor(sig_)};
   caffe2::CPUContext context;
+  c10::core::opschema::LayerNorm::Cache* cache = cache_->GetMutable<c10::core::opschema::LayerNorm::Cache>();
   if (!cache->scale.has_value()) {
-    cache->scale = c10::C10Tensor(caffe2::Tensor{caffe2::CPU});
+    cache->scale = at::Tensor(c10::C10Tensor(caffe2::Tensor{caffe2::CPU}));
   }
   if (!cache->bias.has_value()) {
-    cache->bias = c10::C10Tensor(caffe2::Tensor{caffe2::CPU});
+    cache->bias = at::Tensor(c10::C10Tensor(caffe2::Tensor{caffe2::CPU}));
   }
   caffe2::Tensor scale(*cache->scale);
   caffe2::Tensor bias(*cache->bias);
diff --git a/torch/csrc/jit/c10_ops/layer_norm.cpp b/torch/csrc/jit/c10_ops/layer_norm.cpp
index 705f3343c51110..c23aa07ff04732 100644
--- a/torch/csrc/jit/c10_ops/layer_norm.cpp
+++ b/torch/csrc/jit/c10_ops/layer_norm.cpp
@@ -1,9 +1,12 @@
 #include
 #include
+#include
 #include
 #include
 
-using c10::C10Tensor;
+using at::Tensor;
+using c10::IValue;
+using c10::ArrayRef;
 
 namespace {
 // TODO Return tuple instead of vector
@@ -26,12 +29,24 @@ std::vector layer_norm(
   if (input.requires_grad()) {
     throw std::runtime_error("Autograd not yet supported for c10 ops.");
   }
-  c10::core::opschema::LayerNorm::Cache cache;
-  C10Tensor c10_input(torch::autograd::Variable(std::move(input)).data());
-  C10Tensor c10_output(at::empty({0}));
-  C10Tensor c10_output_mean(at::empty({0}));
-  C10Tensor c10_output_stdev(at::empty({0}));
-  c10::Dispatcher<c10::core::opschema::LayerNorm>::call(c10_input, c10_output, c10_output_mean, c10_output_stdev, (int)axis, (float)epsilon, &cache);
+
+  c10::intrusive_ptr<caffe2::Blob> cache = c10::make_intrusive<caffe2::Blob>();
+  cache->GetMutable<c10::core::opschema::LayerNorm::Cache>(); // initialize cache
+
+  Tensor c10_input(torch::autograd::Variable(std::move(input)).data());
+  Tensor c10_output(at::empty({0}));
+  Tensor c10_output_mean(at::empty({0}));
+  Tensor c10_output_stdev(at::empty({0}));
+
+  c10::Dispatcher<c10::core::opschema::LayerNorm>::call(ArrayRef<IValue>{
+    IValue(c10_input),
+    IValue(c10_output),
+    IValue(c10_output_mean),
+    IValue(c10_output_stdev),
+    IValue(axis),
+    IValue(epsilon),
+    IValue(cache)
+  });
 
   return {
       torch::autograd::make_variable(at::Tensor(std::move(c10_output)), false),
       torch::autograd::make_variable(at::Tensor(std::move(c10_output_mean)), false),
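On the caller side the pattern is always the same: every argument, including the state blob, is boxed into an IValue and the whole pack is passed as one ArrayRef, exactly as the layer_norm wrapper above now does. A generic helper for that would look roughly like the sketch below; the helper itself is not part of this patch, and Dispatcher<OpSchemaDef>::call taking ArrayRef<IValue> is assumed from the dispatcher changes earlier in the diff:

// Sketch only, not code from this patch.
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/ivalue.h>
#include <c10/util/ArrayRef.h>
#include <utility>
#include <vector>

template <class OpSchemaDef, class... Args>
void call_boxed(Args&&... args) {
  // Box every argument (tensors, scalars, the state blob) into one IValue each.
  std::vector<c10::IValue> stack{c10::IValue(std::forward<Args>(args))...};
  // Hand the boxed arguments to the dispatcher in a single ArrayRef.
  c10::Dispatcher<OpSchemaDef>::call(c10::ArrayRef<c10::IValue>(stack));
}

// Usage, equivalent to the explicit call in layer_norm.cpp above:
//   call_boxed<c10::core::opschema::LayerNorm>(
//       c10_input, c10_output, c10_output_mean, c10_output_stdev,
//       (int)axis, (float)epsilon, cache);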