From e49727935aa66990aafcb4bd781a23552a949516 Mon Sep 17 00:00:00 2001 From: tianhaodongbd <137985359+tianhaodongbd@users.noreply.github.com> Date: Thu, 28 Sep 2023 20:14:14 +0800 Subject: [PATCH] compilation optimization for lerp_grad_kernel (#57821) --- paddle/phi/kernels/gpu/lerp_grad_kernel.cu | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu index 43cf0deab6dd9d..d18c769b5117d0 100644 --- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu @@ -25,6 +25,7 @@ #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/gpu/reduce.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" namespace phi { @@ -248,9 +249,8 @@ void LerpGradKernel(const Context& ctx, b_xgrad.dims(), -1); if (!reduce_axis_x.empty()) { - phi::funcs:: - ReduceKernel>( - ctx, b_xgrad, x_grad, kps::IdentityFunctor(), reduce_axis_x); + phi::SumKernel( + ctx, b_xgrad, reduce_axis_x, b_xgrad.dtype(), false, x_grad); } else { x_grad->ShareDataWith(b_xgrad); } @@ -262,9 +262,8 @@ void LerpGradKernel(const Context& ctx, b_ygrad.dims(), -1); if (!reduce_axis_y.empty()) { - phi::funcs:: - ReduceKernel>( - ctx, b_ygrad, y_grad, kps::IdentityFunctor(), reduce_axis_y); + phi::SumKernel( + ctx, b_ygrad, reduce_axis_y, b_ygrad.dtype(), false, y_grad); } else { y_grad->ShareDataWith(b_ygrad); }