[XPU] supports more dtypes of expand/reduce_sum/max/min (PaddlePaddle#70428)

* [XPU] supports more dtypes of expand/reduce_sum/max/min

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
dynamicheart authored Dec 26, 2024
1 parent 1e9f2b3 commit e3fb7d6
Showing 9 changed files with 167 additions and 13 deletions.
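Before the file-by-file diff, a minimal usage sketch (not part of the commit) of a few op/dtype combinations that the changes below register on XPU. It assumes a Paddle build with XPU support and an available XPU device, and uses only public Python APIs shown or implied by this diff.

# Hedged sketch, not from this commit: exercising some of the newly
# registered XPU dtype coverage. Assumes an XPU build and device.
import paddle

paddle.set_device("xpu")

# reduce_sum is now also registered for bool inputs on XPU.
mask = paddle.to_tensor([True, False, True])
print(paddle.sum(mask))

# elementwise_max gains int32/int64 coverage.
a = paddle.to_tensor([1, 5, 3], dtype="int64")
b = paddle.to_tensor([4, 2, 3], dtype="int64")
print(paddle.maximum(a, b))

# expand_v2 / expand_as_v2 gain float64 coverage.
x = paddle.ones([1, 3], dtype="float64")
print(paddle.expand(x, shape=[2, 3]))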
27 changes: 20 additions & 7 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -339,19 +339,25 @@ XPUOpMap& get_kl3_ops() {
{"elementwise_floordiv",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32})},
{"elementwise_max_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"elementwise_max",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
phi::DataType::BFLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"elementwise_min_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"elementwise_min",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
phi::DataType::BFLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"elementwise_mul_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
@@ -421,20 +427,23 @@ XPUOpMap& get_kl3_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"round", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"expand_as_v2",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
phi::DataType::FLOAT32,
phi::DataType::FLOAT64})},
{"expand_v2",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32,
phi::DataType::BFLOAT16})},
phi::DataType::BFLOAT16,
phi::DataType::FLOAT64})},
{"fast_where_xpu",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::FLOAT32,
@@ -946,14 +955,18 @@ XPUOpMap& get_kl3_ops() {
{"reduce_sum_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
phi::DataType::BFLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"reduce_sum",
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::INT8,
phi::DataType::FLOAT32,
phi::DataType::BFLOAT16})},
phi::DataType::BFLOAT16,
phi::DataType::BOOL})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu_grad",
3 changes: 2 additions & 1 deletion paddle/phi/kernels/reduce_sum_kernel.cc
@@ -100,7 +100,8 @@ PD_REGISTER_KERNEL(sum,
phi::dtype::bfloat16,
int8_t,
int,
int64_t) {
int64_t,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
#endif
35 changes: 35 additions & 0 deletions paddle/phi/kernels/xpu/activation_kernel.cc
@@ -190,6 +190,27 @@ struct XPULeakyReluFunctor : public funcs::BaseActivationFunctor<T> {
}
};

template <typename T>
struct XPURoundFunctor : public funcs::BaseActivationFunctor<T> {
int decimals;
std::vector<std::pair<const char*, int*>> GetAttrs() {
return {{"decimals", &decimals}};
}

template <typename Context>
void operator()(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) const {
using XPUType = typename XPUTypeTrait<T>::Type;
int r = xpu::round<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
x.numel(),
decimals);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "round");
}
};

template <typename T, typename Context>
void PowKernel(const Context& dev_ctx,
const DenseTensor& x,
@@ -580,6 +601,17 @@ void HardSwishKernel(const Context& dev_ctx,
dev_ctx, x, out, functor);
}

template <typename T, typename Context>
void RoundKernel(const Context& dev_ctx,
const DenseTensor& x,
const int decimals,
DenseTensor* out) {
XPURoundFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = decimals;
ActivationXPUImpl<T, Context, XPURoundFunctor<T>>(dev_ctx, x, out, functor);
}

} // namespace phi

PD_REGISTER_KERNEL(relu,
@@ -694,6 +726,9 @@ PD_REGISTER_KERNEL(exp,
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(
round, XPU, ALL_LAYOUT, phi::RoundKernel, float, phi::dtype::float16) {}

#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}

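As a usage-level illustration (not part of the diff), the new XPU round registration above is reachable through the Python API. A minimal sketch, assuming an XPU device; it mirrors the test added later in this commit.

import numpy as np
import paddle

paddle.set_device("xpu")
x = paddle.to_tensor([1.234, -5.678, 9.5], dtype="float32")
# 'decimals' is forwarded to the kernel as the attribute handled by XPURoundFunctor.
y = paddle.round(x, decimals=2)
np.testing.assert_allclose(y.numpy(), np.round(x.numpy(), decimals=2), rtol=1e-3)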
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/expand_as_kernel.cc
@@ -126,6 +126,7 @@ PD_REGISTER_KERNEL(expand_as,
XPU,
ALL_LAYOUT,
phi::ExpandAsKernel,
double,
float,
phi::dtype::bfloat16,
phi::dtype::float16,
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/expand_kernel.cc
@@ -128,6 +128,7 @@ PD_REGISTER_KERNEL(expand,
XPU,
ALL_LAYOUT,
phi::ExpandKernel,
double,
float,
phi::dtype::float16,
bool,
28 changes: 24 additions & 4 deletions paddle/phi/kernels/xpu/reduce_sum_grad_kernel.cc
@@ -14,8 +14,11 @@
#include "paddle/phi/kernels/reduce_sum_grad_kernel.h"

#include <set>

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"

namespace phi {

@@ -63,9 +66,23 @@ void ReduceSumGradKernel(const Context& dev_ctx,
ydims = std::vector<int>({1});
}

int r = xpu::broadcast<XPUType>(
dev_ctx.x_context(), out_data, x_grad_data, ydims, xdims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
if (x.dtype() != out_grad.dtype()) {
DenseTensorMeta x_grad_meta(
out_grad.dtype(), x_grad->dims(), x_grad->layout());
DenseTensor x_grad_tmp =
phi::Empty<Context>(dev_ctx, std::move(x_grad_meta));
auto* x_grad_tmp_data = reinterpret_cast<XPUType*>(x_grad_tmp.data());

int r = xpu::broadcast<XPUType>(
dev_ctx.x_context(), out_data, x_grad_tmp_data, ydims, xdims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");

phi::CastKernel<T>(dev_ctx, x_grad_tmp, x.dtype(), x_grad);
} else {
int r = xpu::broadcast<XPUType>(
dev_ctx.x_context(), out_data, x_grad_data, ydims, xdims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
}
}

} // namespace phi
@@ -76,6 +93,9 @@ PD_REGISTER_KERNEL(sum_grad,
phi::ReduceSumGradKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {
phi::dtype::bfloat16,
int64_t,
int,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
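A brief illustration (not part of the diff) of the dtype-mismatch situation the new cast-then-broadcast branch above is meant to cover: when the forward sum promotes the output dtype, out_grad arrives in the promoted dtype and x_grad has to be cast back to x's dtype. A minimal sketch, assuming an XPU device:

import paddle

paddle.set_device("xpu")
x = paddle.ones([2, 3], dtype="float16")
x.stop_gradient = False
y = paddle.sum(x, dtype="float32")  # forward output promoted to float32
y.backward()                        # out_grad is float32; x.grad is cast back to float16
print(x.grad.dtype)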
3 changes: 2 additions & 1 deletion paddle/phi/kernels/xpu/reduce_sum_kernel.cc
@@ -46,6 +46,7 @@ PD_REGISTER_KERNEL(sum_raw,
phi::dtype::bfloat16,
int8_t,
int,
int64_t) {
int64_t,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
3 changes: 3 additions & 0 deletions test/legacy_test/test_minimum_op.py
@@ -93,6 +93,9 @@ def test_static_api(self):
)
np.testing.assert_allclose(res, self.np_expected4, rtol=1e-05)

@unittest.skipIf(
core.is_compiled_with_xpu(), "XPU int64_t minimum has bug now"
)
def test_dynamic_api(self):
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
79 changes: 79 additions & 0 deletions test/xpu/test_activation_op_xpu.py
@@ -15,6 +15,7 @@

import os
import unittest
from contextlib import contextmanager

import numpy as np
from get_test_cover_info import (
@@ -32,6 +33,15 @@
paddle.enable_static()


@contextmanager
def dynamic_guard():
paddle.disable_static()
try:
yield
finally:
paddle.enable_static()


class TestActivationOPBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
@@ -90,6 +100,75 @@ def set_shape(self):
create_test_class(globals(), XPUTestExpOP, stype)


class XPUTestRoundOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'round'
self.use_dynamic_create_class = False

class XPUTestRound(TestActivationOPBase):
def set_case(self):
self.op_type = 'round'

self.init_dtype()
self.set_shape()
self.set_decimals()

np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) * 100
out = np.round(x, decimals=self.decimals)

self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {'decimals': self.decimals}
self.convert_input_output()

def set_shape(self):
self.shape = [10, 12]

def set_decimals(self):
self.decimals = 0

def test_check_grad(self):
pass

def convert_input_output(self):
if self.dtype == np.uint16:
self.inputs = {'X': convert_float_to_uint16(self.inputs['X'])}
self.outputs = {
'Out': convert_float_to_uint16(self.outputs['Out'])
}

class XPUTestRound_ZeroDIm(XPUTestRound):
def set_shape(self):
self.shape = []

class XPUTestRound_decimals1(XPUTestRound):
def set_decimals(self):
self.decimals = 2

def test_round_api(self):
with dynamic_guard():
x_np = (
np.random.uniform(-1, 1, self.shape).astype(self.dtype)
* 100
)
out_expect = np.round(x_np, decimals=self.decimals)
x_paddle = paddle.to_tensor(
x_np, dtype=self.dtype, place=self.place
)
y = paddle.round(x_paddle, decimals=self.decimals)
np.testing.assert_allclose(y.numpy(), out_expect, rtol=1e-3)

class TestRound_decimals2(XPUTestRound_decimals1):
def set_decimals(self):
self.decimals = -1


support_types = get_xpu_op_support_types('round')
for stype in support_types:
create_test_class(globals(), XPUTestRoundOP, stype)


class XPUTestSiluOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'silu'
Expand Down
