From e3fb7d623633d0c425355b7f3dc355f612e0a07a Mon Sep 17 00:00:00 2001
From: Jianbang Yang
Date: Thu, 26 Dec 2024 16:26:03 +0800
Subject: [PATCH] [XPU] supports more dtypes of expand/reduce_sum/max/min
 (#70428)

* [XPU] supports more dtypes of expand/reduce_sum/max/min

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
---
 paddle/phi/backends/xpu/xpu3_op_list.cc       | 27 +++++--
 paddle/phi/kernels/reduce_sum_kernel.cc       |  3 +-
 paddle/phi/kernels/xpu/activation_kernel.cc   | 35 ++++++++
 paddle/phi/kernels/xpu/expand_as_kernel.cc    |  1 +
 paddle/phi/kernels/xpu/expand_kernel.cc       |  1 +
 .../phi/kernels/xpu/reduce_sum_grad_kernel.cc | 28 ++++++-
 paddle/phi/kernels/xpu/reduce_sum_kernel.cc   |  3 +-
 test/legacy_test/test_minimum_op.py           |  3 +
 test/xpu/test_activation_op_xpu.py            | 79 +++++++++++++++++++
 9 files changed, 167 insertions(+), 13 deletions(-)

diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc
index 68e3037257a481..88e042994d5a6e 100644
--- a/paddle/phi/backends/xpu/xpu3_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -339,19 +339,25 @@ XPUOpMap& get_kl3_ops() {
       {"elementwise_floordiv",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
-                     phi::DataType::BFLOAT16})},
+                     phi::DataType::BFLOAT16,
+                     phi::DataType::INT64,
+                     phi::DataType::INT32})},
       {"elementwise_max_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"elementwise_max",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
-                     phi::DataType::BFLOAT16})},
+                     phi::DataType::BFLOAT16,
+                     phi::DataType::INT32,
+                     phi::DataType::INT64})},
       {"elementwise_min_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"elementwise_min",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
-                     phi::DataType::BFLOAT16})},
+                     phi::DataType::BFLOAT16,
+                     phi::DataType::INT32,
+                     phi::DataType::INT64})},
       {"elementwise_mul_grad",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
@@ -421,20 +427,23 @@ XPUOpMap& get_kl3_ops() {
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
                      phi::DataType::BFLOAT16})},
+      {"round", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"expand_as_v2",
        XPUKernelSet({phi::DataType::INT32,
                      phi::DataType::INT64,
                      phi::DataType::BOOL,
                      phi::DataType::BFLOAT16,
                      phi::DataType::FLOAT16,
-                     phi::DataType::FLOAT32})},
+                     phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT64})},
       {"expand_v2",
        XPUKernelSet({phi::DataType::INT32,
                      phi::DataType::INT64,
                      phi::DataType::BOOL,
                      phi::DataType::FLOAT16,
                      phi::DataType::FLOAT32,
-                     phi::DataType::BFLOAT16})},
+                     phi::DataType::BFLOAT16,
+                     phi::DataType::FLOAT64})},
       {"fast_where_xpu",
        XPUKernelSet({phi::DataType::INT32,
                      phi::DataType::FLOAT32,
@@ -946,14 +955,18 @@ XPUOpMap& get_kl3_ops() {
       {"reduce_sum_grad",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::FLOAT16,
-                     phi::DataType::BFLOAT16})},
+                     phi::DataType::BFLOAT16,
+                     phi::DataType::INT64,
+                     phi::DataType::INT32,
+                     phi::DataType::BOOL})},
       {"reduce_sum",
        XPUKernelSet({phi::DataType::FLOAT16,
                      phi::DataType::INT64,
                      phi::DataType::INT32,
                      phi::DataType::INT8,
                      phi::DataType::FLOAT32,
-                     phi::DataType::BFLOAT16})},
+                     phi::DataType::BFLOAT16,
+                     phi::DataType::BOOL})},
       {"relu6", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"relu_grad",
diff --git a/paddle/phi/kernels/reduce_sum_kernel.cc b/paddle/phi/kernels/reduce_sum_kernel.cc
index 2975e272c1e6d4..654eae919905fe 100644
--- a/paddle/phi/kernels/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/reduce_sum_kernel.cc
@@ -100,7 +100,8 @@ PD_REGISTER_KERNEL(sum,
                    phi::dtype::bfloat16,
                    int8_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
 }
 #endif
diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc
index 24be07e64f144a..60f7629c8c3ebd 100644
--- a/paddle/phi/kernels/xpu/activation_kernel.cc
+++ b/paddle/phi/kernels/xpu/activation_kernel.cc
@@ -190,6 +190,27 @@ struct XPULeakyReluFunctor : public funcs::BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct XPURoundFunctor : public funcs::BaseActivationFunctor<T> {
+  int decimals;
+  std::vector<std::pair<const char*, int*>> GetAttrs() {
+    return {{"decimals", &decimals}};
+  }
+
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    using XPUType = typename XPUTypeTrait<T>::Type;
+    int r = xpu::round(dev_ctx.x_context(),
+                       reinterpret_cast<const XPUType*>(x.data<T>()),
+                       reinterpret_cast<XPUType*>(out->data<T>()),
+                       x.numel(),
+                       decimals);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "round");
+  }
+};
+
 template <typename T, typename Context>
 void PowKernel(const Context& dev_ctx,
                const DenseTensor& x,
@@ -580,6 +601,17 @@ void HardSwishKernel(const Context& dev_ctx,
       dev_ctx, x, out, functor);
 }
 
+template <typename T, typename Context>
+void RoundKernel(const Context& dev_ctx,
+                 const DenseTensor& x,
+                 const int decimals,
+                 DenseTensor* out) {
+  XPURoundFunctor<T> functor;
+  auto attrs = functor.GetAttrs();
+  *(attrs[0].second) = decimals;
+  ActivationXPUImpl<T, Context, XPURoundFunctor<T>>(dev_ctx, x, out, functor);
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(relu,
@@ -694,6 +726,9 @@ PD_REGISTER_KERNEL(exp,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
 
+PD_REGISTER_KERNEL(
+    round, XPU, ALL_LAYOUT, phi::RoundKernel, float, phi::dtype::float16) {}
+
 #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
   PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
 
diff --git a/paddle/phi/kernels/xpu/expand_as_kernel.cc b/paddle/phi/kernels/xpu/expand_as_kernel.cc
index b49394743a1165..9c27ce95c02dd7 100644
--- a/paddle/phi/kernels/xpu/expand_as_kernel.cc
+++ b/paddle/phi/kernels/xpu/expand_as_kernel.cc
@@ -126,6 +126,7 @@ PD_REGISTER_KERNEL(expand_as,
                    XPU,
                    ALL_LAYOUT,
                    phi::ExpandAsKernel,
+                   double,
                    float,
                    phi::dtype::bfloat16,
                    phi::dtype::float16,
diff --git a/paddle/phi/kernels/xpu/expand_kernel.cc b/paddle/phi/kernels/xpu/expand_kernel.cc
index dead87d33c514e..094527b027c5e2 100644
--- a/paddle/phi/kernels/xpu/expand_kernel.cc
+++ b/paddle/phi/kernels/xpu/expand_kernel.cc
@@ -128,6 +128,7 @@ PD_REGISTER_KERNEL(expand,
                    XPU,
                    ALL_LAYOUT,
                    phi::ExpandKernel,
+                   double,
                    float,
                    phi::dtype::float16,
                    bool,
diff --git a/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
index 0bce6c727e92ac..15bff0dc8e4577 100644
--- a/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
@@ -14,8 +14,11 @@
 #include "paddle/phi/kernels/reduce_sum_grad_kernel.h"
 
 #include <vector>
+
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/empty_kernel.h"
 
 namespace phi {
 
@@ -63,9 +66,23 @@ void ReduceSumGradKernel(const Context& dev_ctx,
     ydims = std::vector<int>({1});
   }
 
-  int r = xpu::broadcast<XPUType>(
-      dev_ctx.x_context(), out_data, x_grad_data, ydims, xdims);
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
+  if (x.dtype() != out_grad.dtype()) {
+    DenseTensorMeta x_grad_meta(
+        out_grad.dtype(), x_grad->dims(), x_grad->layout());
+    DenseTensor x_grad_tmp =
+        phi::Empty<T, Context>(dev_ctx, std::move(x_grad_meta));
+    auto* x_grad_tmp_data = reinterpret_cast<XPUType*>(x_grad_tmp.data());
+
+    int r = xpu::broadcast<XPUType>(
+        dev_ctx.x_context(), out_data, x_grad_tmp_data, ydims, xdims);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
+
+    phi::CastKernel<T>(dev_ctx, x_grad_tmp, x.dtype(), x_grad);
+  } else {
+    int r = xpu::broadcast<XPUType>(
+        dev_ctx.x_context(), out_data, x_grad_data, ydims, xdims);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
+  }
 }
 
 }  // namespace phi
@@ -76,6 +93,9 @@ PD_REGISTER_KERNEL(sum_grad,
                    phi::ReduceSumGradKernel,
                    float,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {
+                   phi::dtype::bfloat16,
+                   int64_t,
+                   int,
+                   bool) {
   kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
 }
diff --git a/paddle/phi/kernels/xpu/reduce_sum_kernel.cc b/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
index 6d6def595e74a4..47d2e27c4cc2b3 100644
--- a/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_sum_kernel.cc
@@ -46,6 +46,7 @@ PD_REGISTER_KERNEL(sum_raw,
                    phi::dtype::bfloat16,
                    int8_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
 }
diff --git a/test/legacy_test/test_minimum_op.py b/test/legacy_test/test_minimum_op.py
index 82ebbad5d1410d..cd60d7e0f6e2d0 100644
--- a/test/legacy_test/test_minimum_op.py
+++ b/test/legacy_test/test_minimum_op.py
@@ -93,6 +93,9 @@ def test_static_api(self):
             )
             np.testing.assert_allclose(res, self.np_expected4, rtol=1e-05)
 
+    @unittest.skipIf(
+        core.is_compiled_with_xpu(), "XPU int64_t minimum has bug now"
+    )
     def test_dynamic_api(self):
         paddle.disable_static()
         x = paddle.to_tensor(self.input_x)
diff --git a/test/xpu/test_activation_op_xpu.py b/test/xpu/test_activation_op_xpu.py
index e64dd7abd1839d..5a9b09f4e72176 100644
--- a/test/xpu/test_activation_op_xpu.py
+++ b/test/xpu/test_activation_op_xpu.py
@@ -15,6 +15,7 @@
 
 import os
 import unittest
+from contextlib import contextmanager
 
 import numpy as np
 from get_test_cover_info import (
@@ -32,6 +33,15 @@
 paddle.enable_static()
 
 
+@contextmanager
+def dynamic_guard():
+    paddle.disable_static()
+    try:
+        yield
+    finally:
+        paddle.enable_static()
+
+
 class TestActivationOPBase(XPUOpTest):
     def setUp(self):
         self.place = paddle.XPUPlace(0)
@@ -90,6 +100,75 @@ def set_shape(self):
     create_test_class(globals(), XPUTestExpOP, stype)
 
 
+class XPUTestRoundOP(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'round'
+        self.use_dynamic_create_class = False
+
+    class XPUTestRound(TestActivationOPBase):
+        def set_case(self):
+            self.op_type = 'round'
+
+            self.init_dtype()
+            self.set_shape()
+            self.set_decimals()
+
+            np.random.seed(1024)
+            x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) * 100
+            out = np.round(x, decimals=self.decimals)
+
+            self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
+            self.outputs = {'Out': out}
+            self.attrs = {'decimals': self.decimals}
+            self.convert_input_output()
+
+        def set_shape(self):
+            self.shape = [10, 12]
+
+        def set_decimals(self):
+            self.decimals = 0
+
+        def test_check_grad(self):
+            pass
+
+        def convert_input_output(self):
+            if self.dtype == np.uint16:
+                self.inputs = {'X': convert_float_to_uint16(self.inputs['X'])}
+                self.outputs = {
+                    'Out': convert_float_to_uint16(self.outputs['Out'])
+                }
+
+    class XPUTestRound_ZeroDIm(XPUTestRound):
+        def set_shape(self):
+            self.shape = []
+
+    class XPUTestRound_decimals1(XPUTestRound):
+        def init_decimals(self):
+            self.decimals = 2
+
+        def test_round_api(self):
+            with dynamic_guard():
+                x_np = (
+                    np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+                    * 100
+                )
+                out_expect = np.round(x_np, decimals=self.decimals)
+                x_paddle = paddle.to_tensor(
+                    x_np, dtype=self.dtype, place=self.place
+                )
+                y = paddle.round(x_paddle, decimals=self.decimals)
+                np.testing.assert_allclose(y.numpy(), out_expect, rtol=1e-3)
+
+    class TestRound_decimals2(XPUTestRound_decimals1):
+        def init_decimals(self):
+            self.decimals = -1
+
+
+support_types = get_xpu_op_support_types('round')
+for stype in support_types:
+    create_test_class(globals(), XPUTestRoundOP, stype)
+
+
 class XPUTestSiluOP(XPUOpTestWrapper):
     def __init__(self):
         self.op_name = 'silu'
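
Usage sketch (illustrative only, not part of the patch): the snippet below shows what the new registrations allow from the Python side. It assumes a Paddle build with XPU support and the `decimals` keyword of paddle.round that the new test exercises; the device string, shapes, and tolerances are placeholders.

    # Illustrative sketch: requires a Paddle wheel built with XPU support and this patch applied.
    import numpy as np
    import paddle

    if paddle.is_compiled_with_xpu():
        paddle.set_device('xpu')  # dispatch to the newly registered XPU kernels

        # round with the new `decimals` attribute (float32/float16 XPU kernel)
        x = paddle.to_tensor(np.random.uniform(-100, 100, [4, 5]).astype('float32'))
        y = paddle.round(x, decimals=2)
        np.testing.assert_allclose(y.numpy(), np.round(x.numpy(), decimals=2), rtol=1e-5)

        # reduce_sum over a bool tensor, now listed for the XPU backend
        mask = paddle.to_tensor(np.random.rand(8, 8) > 0.5)
        total = paddle.sum(mask)  # output dtype is inferred (promoted), not bool

        # expand_v2 with a float64 input, also newly registered
        z = paddle.ones([1, 3], dtype='float64').expand([4, 3])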