Skip to content

Commit

Permalink
compilation optimization for lerp_grad_kernel (PaddlePaddle#57821)
Browse files: browse the repository at this point in the history
  • Loading branch information
tianhaodongbd authored Sep 28, 2023
1 parent 20b9713 commit e497279
Showing 1 changed file with 5 additions and 6 deletions.
11 changes (5 additions, 6 deletions) in paddle/phi/kernels/gpu/lerp_grad_kernel.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {

Expand Down Expand Up @@ -248,9 +249,8 @@ void LerpGradKernel(const Context& ctx,
b_xgrad.dims(),
-1);
if (!reduce_axis_x.empty()) {
phi::funcs::
ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx, b_xgrad, x_grad, kps::IdentityFunctor<T>(), reduce_axis_x);
phi::SumKernel<T, Context>(
ctx, b_xgrad, reduce_axis_x, b_xgrad.dtype(), false, x_grad);
} else {
x_grad->ShareDataWith(b_xgrad);
}
Expand All @@ -262,9 +262,8 @@ void LerpGradKernel(const Context& ctx,
b_ygrad.dims(),
-1);
if (!reduce_axis_y.empty()) {
phi::funcs::
ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx, b_ygrad, y_grad, kps::IdentityFunctor<T>(), reduce_axis_y);
phi::SumKernel<T, Context>(
ctx, b_ygrad, reduce_axis_y, b_ygrad.dtype(), false, y_grad);
} else {
y_grad->ShareDataWith(b_ygrad);
}
Expand Down

0 comments on commit e497279

Please sign in to comment.