Skip to content

Commit

Permalink
OpenCL Improvements
Browse files Browse the repository at this point in the history
* Registers Scatter and ScatterNd Ops for SYCL

* Registers Stack op for SYCL

* Fixes No sycl buffer found error for debug ops

* Registers MatMul and Transpose Ops to SYCL device for double

* Extends analyzer_cli_test.py test to cover SYCL

* Fixes Transpose Op for double when on SYCL

* Bumps Eigen version to fix double precision issue on SYCL

* Extends SessionDebugTestBase to cover SYCL
  • Loading branch information
Luke Iwanski authored and benoitsteiner committed Feb 25, 2017
1 parent cbcdc6e commit eb0d3a1
Show file tree
Hide file tree
Showing 15 changed files with 313 additions and 12 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,7 @@ cc_library(
hdrs = if_not_windows([
"common_runtime/sycl/sycl_allocator.h",
"common_runtime/sycl/sycl_device.h",
"common_runtime/sycl/sycl_util.h",
"common_runtime/sycl/sycl_device_context.h",
]),
copts = tf_copts(),
Expand Down
35 changes: 35 additions & 0 deletions tensorflow/core/common_runtime/sycl/sycl_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if !TENSORFLOW_USE_SYCL
#error This file must only be included when building TensorFlow with SYCL support
#endif

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_

// For DMA helper
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
inline void* GetBase(const Tensor* src) {
return const_cast<void*>(DMAHelper::base(src));
}

inline void* GetBase(Tensor* dst) { return DMAHelper::base(dst); }
}

#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_UTIL_H_
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ load(
"tf_kernel_library",
"cc_header_only_library",
)
load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")
load(
Expand Down Expand Up @@ -465,7 +466,7 @@ tf_kernel_library(
deps = ARRAY_DEPS + [
"//tensorflow/core:gpu_runtime",
"//tensorflow/core/debug:debug_io_utils",
],
] + if_sycl(["//tensorflow/core:sycl_runtime"]),
)

cc_library(
Expand Down
17 changes: 17 additions & 0 deletions tensorflow/core/kernels/debug_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ limitations under the License.
#define TENSORFLOW_KERNELS_DEBUG_OP_H_

#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif // TENSORFLOW_USE_SYCL
#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
Expand Down Expand Up @@ -63,6 +66,20 @@ class CopyOp : public OpKernel {
// The input tensor is on the host (CPU): deep-copy from CPU to CPU.
*copied_tensor = tensor::DeepCopy(src_tensor);
}
#elif defined(TENSORFLOW_USE_SYCL)
Device* device = static_cast<Device*>(context->device());
// Determine if the input tensor is not on CPU (e.g., on GPU).
bool off_host_input = device->device_type() == DEVICE_SYCL &&
!context->input_alloc_attr(0).on_host();
if(off_host_input) {
auto size = src_tensor.NumElements() * sizeof(src_tensor.dtype());
auto dst_ptr = GetBase(copied_tensor);
auto src_ptr = GetBase(&src_tensor);
typedef decltype(src_tensor.dtype()) ttype;
device->eigen_sycl_device()->memcpy(dst_ptr, static_cast<const ttype *>(src_ptr), size);
} else {
*copied_tensor = tensor::DeepCopy(src_tensor);
}
#else
*copied_tensor = tensor::DeepCopy(src_tensor);
#endif
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/kernels/matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ TF_CALL_half(REGISTER_GPU);
.Label("eigen"), \
MatMulOp<SYCLDevice, T, false /* xxblas */>)
TF_CALL_float(REGISTER_SYCL);
TF_CALL_double(REGISTER_SYCL);

#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
69 changes: 69 additions & 0 deletions tensorflow/core/kernels/scatter_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,50 @@ struct Assign<scatter_op::UpdateOp::DIV> {
}
};

#ifdef TENSORFLOW_USE_SYCL
template <scatter_op::UpdateOp Op>
struct AssignSYCL {};
template <>
struct AssignSYCL<scatter_op::UpdateOp::ASSIGN> {
template <typename Device, typename Params, typename Update>
static void Run(Device d, Params p, Update u) {
p.device(d) = u;
}
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::ADD> {
template <typename Device, typename Params, typename Update>
static void Run(Device d, Params p, Update u) {
p.device(d) += u;
}
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::SUB> {
template <typename Device, typename Params, typename Update>
static void Run(Device d, Params p, Update u) {
p.device(d) -= u;
}
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::MUL> {
template <typename Device, typename Params, typename Update>
static void Run(Device d, Params p, Update u) {
p.device(d) = p * u;
}
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::DIV> {
template <typename Device, typename Params, typename Update>
static void Run(Device d, Params p, Update u) {
p.device(d) = p / u;
}
};
#endif // TENSORFLOW_USE_SYCL

} // namespace internal
} // namespace scatter_op

Expand Down Expand Up @@ -110,6 +154,31 @@ struct ScatterFunctorBase {
}
};

#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase <SYCLDevice, T, Index, op> {
Index operator()(OpKernelContext* c, const SYCLDevice& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices) {
// indices and params sizes were validated in DoCompute().
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
// Grab the index and check its validity. An earlier version of the
// code checked it and then grabbed it from memory a second time, which
// was a security risk since it could have changed in between.
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Copy last Ndim-1 dimensions of updates[i] to params[index]
scatter_op::internal::AssignSYCL<op>::Run(d, params.template chip<0>(index),
updates.template chip<0>(i));
}
return -1;
}
};
#endif // TENSORFLOW_USE_SYCL

template <typename T, typename Index>
struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
Index operator()(OpKernelContext* c, const CPUDevice& d,
Expand Down
16 changes: 16 additions & 0 deletions tensorflow/core/kernels/scatter_nd_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL

// Check whether updates.shape = indices.shape[:batch_dim] +
// params_shape[slice_dim:]
Expand Down Expand Up @@ -415,6 +418,19 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);

#endif // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \
REGISTER_SCATTER_ND_ADD_SUB(type, SYCL);

#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
REGISTER_SCATTER_ND_UPDATE(type, SYCL);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
#undef REGISTER_SCATTER_ND_UPDATE_SYCL
#endif // TENSORFLOW_USE_SYCL

#undef REGISTER_SCATTER_ND_ADD
#undef REGISTER_SCATTER_ND_ADD_SUB
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
Expand Down
88 changes: 88 additions & 0 deletions tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ limitations under the License.
namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL

class OpKernelContext;

Expand Down Expand Up @@ -186,6 +189,91 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH)
#undef REGISTER_SCATTER_ND_INDEX
#undef REGISTER_SCATTER_ND_FULL

#ifdef TENSORFLOW_USE_SYCL

// Implementation of update functor for SYCL.
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
struct ScatterNdFunctor<SYCLDevice, T, Index, OP, IXDIM> {
Index operator()(
const SYCLDevice& d, const Index slice_size,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
typename TTypes<T, 2>::Tensor Tparams,
typename TTypes<Index, 2>::ConstTensor Tindices,
typename TTypes<T, 2>::ConstTensor Tupdates,
typename TTypes<T, 2>::Tensor Toutput) {
// error_loc is -1 if there's no out-of-bounds index,
// otherwise it is the location of an OOB index in Tindices.
Index error_loc = -1;

const Eigen::DenseIndex batch_size = Tindices.dimension(0);

Index batch_strides[IXDIM];
for (int dim = IXDIM - 1; dim >= 0; --dim) {
if (dim == IXDIM - 1) {
batch_strides[dim] = 1;
} else {
batch_strides[dim] =
batch_strides[dim + 1] * output_shape_prefix[dim + 1];
}
}

for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) {
Index i = 0;
bool out_of_bounds = false;
for (int dim = 0; dim < IXDIM; ++dim) {
const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim));
out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
i += ix_d * batch_strides[dim];
}
if (TF_PREDICT_FALSE(out_of_bounds)) {
error_loc = loc;
break;
} else {
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip.device(d);
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
decltype(input_chip), decltype(update_chip), decltype(output_chip),
OP>::Execute(input_chip, update_chip, output_chip);
}
}

return error_loc;
}
};

#define REGISTER_SCATTER_ND_FULL_SYCL(T, Index, op) \
template Index \
ScatterNdFunctor<SYCLDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
const SYCLDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput)

#define REGISTER_SCATTER_ND_INDEX_SYCL(type, op) \
REGISTER_SCATTER_ND_FULL_SYCL(type, int32, op); \
REGISTER_SCATTER_ND_FULL_SYCL(type, int64, op)

#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ASSIGN);

#define REGISTER_SCATTER_ND_MATH_SYCL(type) \
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ADD); \
REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::SUB);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MATH_SYCL)

#undef REGISTER_SCATTER_ND_MATH_SYCL
#undef REGISTER_SCATTER_ND_UPDATE_SYCL
#undef REGISTER_SCATTER_ND_INDEX_SYCL
#undef REGISTER_SCATTER_ND_FULL_SYCL

#endif // TENSORFLOW_USE_SYCL

} // namespace functor

} // namespace tensorflow
Expand Down
Loading

0 comments on commit eb0d3a1

Please sign in to comment.