Skip to content

Commit

Permalink
Make c10 dispatcher use boxed kernel function pointers (pytorch#16051)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#16051
This changes the kernels stored in the c10 dispatcher from plain C function pointers to IValue-based KernelFunction*.

Note that KernelFunction is currently taking an `ArrayRef<IValue>` as arguments. A later diff will change that to it taking a `Stack*`.

Reviewed By: ezyang

Differential Revision: D13684518

fbshipit-source-id: 1fa54f60cec2e967b92a4a043d6e3ac1627ed991
  • Loading branch information
smessmer authored and facebook-github-bot committed Jan 19, 2019
1 parent b662a9b commit c904416
Show file tree
Hide file tree
Showing 49 changed files with 501 additions and 390 deletions.
38 changes: 18 additions & 20 deletions aten/src/ATen/core/dispatch/DispatchTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <c10/util/LeftRight.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/flat_hash_map.h>
#include <ATen/core/ivalue.h>

#include <array>
#include <atomic>
Expand All @@ -20,9 +21,9 @@ template <class Key>
class ThreadsafeOperatorTable_ final {
public:
template <class Key_>
void emplace(Key_&& key, void* value) {
bool res = map_.write([&](ska::flat_hash_map<Key, void*>& map) -> bool {
auto result = map.emplace(std::forward<Key>(key), value);
void emplace(Key_&& key, KernelFunction value) {
bool res = map_.write([&](ska::flat_hash_map<Key, KernelFunction>& map) -> bool {
auto result = map.emplace(std::forward<Key>(key), std::move(value));
return result.second;
});
if (!res) {
Expand All @@ -34,7 +35,7 @@ class ThreadsafeOperatorTable_ final {

void erase(const Key& key) {
auto num_removed =
map_.write([&](ska::flat_hash_map<Key, void*>& map) -> size_t {
map_.write([&](ska::flat_hash_map<Key, KernelFunction>& map) -> size_t {
return map.erase(key);
});
assert(num_removed <= 1); // This is not a multi-map
Expand All @@ -44,19 +45,19 @@ class ThreadsafeOperatorTable_ final {
}
}

void* lookup(const Key& key) const {
return map_.read([&](const ska::flat_hash_map<Key, void*>& map) -> void* {
const KernelFunction* lookup(const Key& key) const {
return map_.read([&](const ska::flat_hash_map<Key, KernelFunction>& map) -> const KernelFunction* {
auto found = map.find(key);
if (found != map.end()) {
return found->second;
return &found->second;
} else {
return nullptr;
}
});
}

private:
LeftRight<ska::flat_hash_map<Key, void*>> map_;
LeftRight<ska::flat_hash_map<Key, KernelFunction>> map_;
};
} // namespace details

Expand Down Expand Up @@ -86,9 +87,9 @@ class DispatchTable final {
* @param dispatch_key Dispatch key to define when this kernel is selected
*/
void registerKernel(
typename Schema::signature::func_type* func,
KernelFunction func,
typename Schema::dispatch::dispatch_key_type dispatch_key) {
kernels_.emplace(std::move(dispatch_key), reinterpret_cast<void*>(func));
kernels_.emplace(std::move(dispatch_key), std::move(func));
}

/**
Expand All @@ -111,31 +112,28 @@ class DispatchTable final {
* @param args Arguments to invoke the function with
* @return Returned value of the operator
*/
template <class... Args>
typename Schema::signature::return_type call(Args&&... args) const {
IValue call(ArrayRef<IValue> args) const {
// TODO Better error message, but need to take care that reference arguments
// match non-reference arguments and so on.
// static_assert(std::is_same<typename Schema::return_type (Args...),
// typename Schema::func_type>::value, "Argument types don't match
// operator signature");
auto kernel_func = lookupKernelFunc_(args...);
return kernel_func(std::forward<Args>(args)...);
const auto& kernel_func = lookupKernelFunc_(args);
return kernel_func(args);
}

private:
template <class... Args>
typename Schema::signature::func_type* lookupKernelFunc_(
const Args&... args) const {
auto dispatch_key = Schema::dispatch::dispatch_key(args...);
void* found = kernels_.lookup(dispatch_key);
const KernelFunction& lookupKernelFunc_(ArrayRef<IValue> args) const {
auto dispatch_key = Schema::dispatch::dispatch_key(args);
const KernelFunction* found = kernels_.lookup(dispatch_key);
if (found == nullptr) {
// TODO Better error message - include op name and dispatch key (i.e.
// argument types)
throw std::logic_error(
std::string() + "Didn't find kernel to dispatch to for operator '" +
Schema::metadata::name() + "'");
}
return reinterpret_cast<typename Schema::signature::func_type*>(found);
return *found;
}

details::ThreadsafeOperatorTable_<
Expand Down
32 changes: 8 additions & 24 deletions aten/src/ATen/core/dispatch/Dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,35 @@ namespace c10 {
*/
template<class OpSchemaDef>
class Dispatcher final {
private:
using Schema = OpSchema<OpSchemaDef>;
public:
// Implementation note: this class abstracts over the fact that we have per-operator
// dispatch tables. This could be easily adjusted to have a single global hash
// table.

/**
* Register an operator to the dispatch table for some operator schema.
*
* @tparam OpSchemaDef Operator schema to register this operator to (mandatory)
* @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::registerOp (inferred)
* @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::registerOp
* @return void
*/
template<class... Args>
static void registerKernel(Args&&... args) {
static void registerKernel(KernelFunction kernel_func, typename Schema::dispatch::dispatch_key_type dispatch_key) {
auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
return dispatch_table_for_this_op.registerKernel(std::forward<Args>(args)...);
return dispatch_table_for_this_op.registerKernel(std::move(kernel_func), std::move(dispatch_key));
}

/**
* Remove an operator from the dispatch table for some operator schema.
*
* @tparam OpSchemaDef Operator schema to deregister from (mandatory)
* @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::deregisterOp (inferred)
* @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::deregisterOp
* @return void
*/
template<class... Args>
static void deregisterKernel(Args&&... args) {
static void deregisterKernel(const typename Schema::dispatch::dispatch_key_type& dispatch_key) {
auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
return dispatch_table_for_this_op.deregisterKernel(std::forward<Args>(args)...);
return dispatch_table_for_this_op.deregisterKernel(dispatch_key);
}

/**
* Perform a dynamic dispatch to some operator
*
* @tparam OpSchemaDef Operator schema to dispatch with (mandatory)
* @tparam Args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::call (inferred)
* @param args Perfect-forwarding args to c10::dispatch::impl::DispatchTable::call
* @return Return type of this operator
*/
template<class... Args>
static typename OpSchema<OpSchemaDef>::signature::return_type call(Args&&... args) {
static IValue call(ArrayRef<IValue> args) {
auto& dispatch_table_for_this_op = c10_dispatch_table<OpSchemaDef>();
return dispatch_table_for_this_op.call(std::forward<Args>(args)...);
return dispatch_table_for_this_op.call(args);
}
};

Expand Down
30 changes: 19 additions & 11 deletions aten/src/ATen/core/dispatch/KernelRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class KernelRegistrar final {
* @param kernel The concrete function implementation to register
* @param dispatch_key The dispatch key to register the function to
*/
KernelRegistrar(typename Schema::signature::func_type* kernel, typename Schema::dispatch::dispatch_key_type dispatch_key)
KernelRegistrar(KernelFunction kernel, typename Schema::dispatch::dispatch_key_type dispatch_key)
: dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
Dispatcher<OpSchemaDef>::registerKernel(kernel, dispatch_key_);
}
Expand Down Expand Up @@ -78,8 +78,7 @@ class KernelRegistrar final {
* The resulting full expression is implicitly convertible to a KernelRegistrar.
*
* @tparam OpSchemaDef The operator schema this is building a KernelRegistration for
* @tparam hasKernel Boolean for compile-time checking that a kernel is specified before finalizing the builder
* @tparam hasDispatchKey Boolean for compile-time checking thhat a dispatch key is specified before finalizing the builder
* @tparam FieldsPresentFlags Remembers which fields are already set in the builder
*/
template<class OpSchemaDef, uint64_t FieldsPresentFlags>
class KernelRegistrationBuilder final {
Expand All @@ -89,15 +88,15 @@ class KernelRegistrationBuilder final {
static constexpr uint64_t KERNEL_PRESENT = 0x01 << 0;
static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 1;

c10::optional<typename Schema::signature::func_type*> kernel_;
c10::optional<KernelFunction> kernel_;
c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_;

public:
constexpr KernelRegistrationBuilder()
KernelRegistrationBuilder()
: KernelRegistrationBuilder(c10::nullopt, c10::nullopt) {}

constexpr KernelRegistrationBuilder(
c10::optional<typename Schema::signature::func_type*> kernel,
KernelRegistrationBuilder(
c10::optional<KernelFunction> kernel,
c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key)
: kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {}

Expand All @@ -106,7 +105,7 @@ class KernelRegistrationBuilder final {
* creates the object.
* @return Produced KernelRegistrar
*/
constexpr operator KernelRegistrar<OpSchemaDef>() && {
operator KernelRegistrar<OpSchemaDef>() && {
static_assert(FieldsPresentFlags & KERNEL_PRESENT, "Forgot to call .kernel() in kernel registration");
static_assert(FieldsPresentFlags & DISPATCH_KEY_PRESENT, "Forgot to call .dispatchKey() in kernel registration");
return KernelRegistrar<OpSchemaDef>(std::move(*kernel_), std::move(*dispatch_key_));
Expand All @@ -117,17 +116,26 @@ class KernelRegistrationBuilder final {
* @param kernel concrete function implementation to be registered
* @return "this" for method chaining
*/
constexpr KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(typename Schema::signature::func_type* kernel_func) && {
KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(KernelFunction kernel_func) && {
static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration");
return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(*kernel_func, std::move(dispatch_key_));
return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(std::move(kernel_func), std::move(dispatch_key_));
}

/**
* Specify the concrete function implementation for this dispatch registration
* @param kernel concrete function implementation to be registered
* @return "this" for method chaining
*/
KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(typename Schema::signature::func_type* kernel_func) && {
return std::move(*this).kernel(Schema::signature::wrap_kernel(kernel_func));
}

/**
* Specify the dispatch key for this dispatch registration
* @param dispatch_key dispatch key to register the function to
* @return "this" for method chaining
*/
constexpr KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | DISPATCH_KEY_PRESENT> dispatchKey(typename Schema::dispatch::dispatch_key_type dispatch_key) && {
KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | DISPATCH_KEY_PRESENT> dispatchKey(typename Schema::dispatch::dispatch_key_type dispatch_key) && {
static_assert(!(FieldsPresentFlags & DISPATCH_KEY_PRESENT), "Tried to define kernel twice in same op registration");
return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | DISPATCH_KEY_PRESENT>(std::move(kernel_), std::move(dispatch_key));
}
Expand Down
Loading

0 comments on commit c904416

Please sign in to comment.