Skip to content

Commit

Permalink
static graph autogen code support for full_like op (#54698)
Browse files Browse the repository at this point in the history
* static graph autogen code support for full_like op

* fix

* fix bug
  • Loading branch information
GreatV authored Jun 19, 2023
1 parent 93f7a02 commit 8947488
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 130 deletions.
98 changes: 0 additions & 98 deletions paddle/fluid/operators/fill_any_like_op.cc

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ phi::KernelKey GetExpectedKernelType(
{% endif %}
{% elif kernel["data_type"]["candidates"] | length == 2 %}
{% set data_type_args = kernel["data_type"]["candidates"] %}
auto data_type = framework::proto::VarType::Type(ctx.Attr<int>("{{data_type_args[0]}}");
if (data_type == static_cast<proto::VarType::Type>(-1)) {
auto data_type = framework::proto::VarType::Type(ctx.Attr<int>("{{data_type_args[0]}}"));
if (data_type == static_cast<framework::proto::VarType::Type>(-1)) {
data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_args[1] | to_opmaker_name}});
}
{% endif %}
Expand Down
6 changes: 4 additions & 2 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1095,8 +1095,10 @@
x : X
outputs :
out : Out
attrs :
{value: value, dtype: dtype}
scalar :
value :
data_type : float
support_tensor : true

- op : fused_conv2d
extra :
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@
param : [x, axis, keepdim, reduce_all]
backward : frobenius_norm_grad

- op : full_like
args : (Tensor x, Scalar value = 0.0, DataType dtype = DataType::UNDEFINED)
output: Tensor(out)
infer_meta :
func : FillAnyLikeInferMeta
kernel :
func : full_like
param : [x, value, dtype]
data_type : dtype > x

- op : gaussian
args : (IntArray shape = {}, float mean = .0f, float std = 1.0f, int seed = 0, DataType dtype = DataType::FLOAT32)
output: Tensor(out)
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,15 @@ void ExpandInferMeta(const MetaTensor& x,
}
}

void FillAnyLikeInferMeta(const MetaTensor& x,
const Scalar& value,
DataType dtype,
MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
out->share_lod(x);
}

void FillDiagonalInferMeta(
const MetaTensor& x, float value, int offset, bool wrap, MetaTensor* out) {
PADDLE_ENFORCE_NE(
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out);

void FillAnyLikeInferMeta(const MetaTensor& x,
const Scalar& value,
DataType dtype,
MetaTensor* out);

void FillDiagonalInferMeta(
const MetaTensor& x, float value, int offset, bool wrap, MetaTensor* out);

Expand Down
28 changes: 0 additions & 28 deletions paddle/phi/ops/compat/fill_any_like_sig.cc

This file was deleted.

0 comments on commit 8947488

Please sign in to comment.