Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve new executor static build #51149

Merged
merged 57 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit. Hold shift + click to select a range.
d9c535f
Improve new executor static build
From00 Mar 3, 2023
bf50fb7
Skip GC for static build
From00 Mar 5, 2023
091e96e
Skip infershape for static build
From00 Mar 6, 2023
4d418b8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 6, 2023
e2c53b6
Handle read_op
From00 Mar 7, 2023
432e143
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 8, 2023
9bd46a1
Add fused_attention to OpsWithFluidKernelNeedMoveToPhi
From00 Mar 9, 2023
6db222a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 9, 2023
7c710bf
Fix argsort typos
From00 Mar 10, 2023
c6a115c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 10, 2023
c0642ce
Add sequence_pool to OpsWithFluidKernelNeedMoveToPhi
From00 Mar 12, 2023
7669682
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 12, 2023
696122d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 12, 2023
3723697
Fix skip share lod errors
From00 Mar 12, 2023
205906d
Fix errors for adam
From00 Mar 14, 2023
6840f5b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 14, 2023
405ac93
Fix errors for eigvals, memcpy and fake_quantize
From00 Mar 18, 2023
5004e56
Add static_build.cc
From00 Mar 18, 2023
74729bc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 18, 2023
fc92a6a
Add black list
From00 Mar 19, 2023
e9ee0bc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 19, 2023
cafc9c6
Fix CI errors
From00 Mar 21, 2023
46a4dc9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 21, 2023
c2ef1d3
Fix CI errors
From00 Mar 21, 2023
002b338
Fix CI errors
From00 Mar 23, 2023
f0891eb
Fix CI errors
From00 Mar 23, 2023
9132af3
Fix TensorArray
From00 Mar 24, 2023
4fa2713
Fix TensorArray
From00 Mar 24, 2023
5e1fae3
Add update_loss_scaling to OpsNeedSetOutputDtypeWhenRegisterPhiKernel
From00 Mar 25, 2023
718a283
Fix copy
From00 Mar 25, 2023
76132e6
Fix errors
From00 Mar 25, 2023
f28c1c1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 26, 2023
89f4283
Fix momentum
From00 Mar 26, 2023
a79e298
Skip mkldnn
From00 Mar 26, 2023
0e2b3fe
Fix CI errors
From00 Mar 27, 2023
307a445
Fix c_sync_calc_stream_op
From00 Mar 28, 2023
e5bd4cd
Fix CINN
From00 Mar 30, 2023
fa8e04f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Mar 30, 2023
886433e
Fix while op
From00 Mar 31, 2023
f456101
All CI pass, disable FLAGS to merge code, enable it after more tests …
From00 Apr 1, 2023
94c9809
Add UTs
From00 Apr 1, 2023
c9a1f3d
Fix typos
From00 Apr 1, 2023
aa76cbb
Fix typos
From00 Apr 1, 2023
1eb7fd8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
From00 Apr 1, 2023
a56856d
Add mkldnn UT
From00 Apr 1, 2023
b4a6077
Remove mkldnn test
From00 Apr 1, 2023
62f0e7b
Fix typos
From00 Apr 1, 2023
63babbd
Fix dist test
From00 Apr 1, 2023
ed4b989
Fix typos
From00 Apr 1, 2023
986ad7a
Fix CI errors
From00 Apr 2, 2023
99a4add
Fix CI errors
From00 Apr 2, 2023
4934047
Add UTs
From00 Apr 2, 2023
4c76d00
Fix typos
From00 Apr 2, 2023
10c47fb
Fix typos
From00 Apr 2, 2023
d5a67ea
Add sparse tests
From00 Apr 3, 2023
b5fa80d
ToComplexType -> ToComplex
From00 Apr 3, 2023
606f016
Add test_matmul_op_static_build to disable_win_inference_test
From00 Apr 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
set(INTERPRETER_SRCS data_transfer.cc dependency_builder.cc execution_config.cc
interpreter_util.cc stream_analyzer.cc)
interpreter_util.cc static_build.cc stream_analyzer.cc)

set(INTERPRETER_DEPS
buffered_reader
device_context
global_utils
op_registry
phi_tensor_utils
scope
framework_proto
data_feed_proto
Expand Down
41 changes: 25 additions & 16 deletions paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/static_build.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"

Expand All @@ -37,7 +38,7 @@ bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var,
std::vector<OpFuncNode>* op_func_nodes,
bool use_local_scope,
bool is_fetch_v2,
bool skip_run) {
bool static_build) {
bool is_transferred = false;
auto* src_var_name = &var_name;

Expand All @@ -52,7 +53,7 @@ bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var,
is_fetch_v2);
if (op) {
RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes, skip_run);
op, *src_var_name, *new_var_name, op_func_nodes, static_build);
}
// update src_var_name
src_var_name = new_var_name;
Expand All @@ -70,7 +71,7 @@ bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var,
scope_);
if (op) {
RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes, skip_run);
op, *src_var_name, *new_var_name, op_func_nodes, static_build);
}
// update src_var_name
src_var_name = new_var_name;
Expand All @@ -87,7 +88,7 @@ bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var,
*src_var_name, new_var_name, src_place, dst_place, var_scope_, scope_);
if (op) {
RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes, skip_run);
op, *src_var_name, *new_var_name, op_func_nodes, static_build);
}
is_transferred = true;
}
Expand All @@ -98,7 +99,7 @@ void DataTranferHelper::RunAndConstructShareNode(
const std::string& src_var_name,
const std::string& dst_var_name,
std::vector<OpFuncNode>* op_func_nodes,
bool skip_run) {
bool static_build) {
VariableNameMap in_name_map = {{"X", {src_var_name}}};
VariableNameMap out_name_map = {{"Out", {dst_var_name}}};
AttributeMap attr_map;
Expand All @@ -112,23 +113,26 @@ void DataTranferHelper::RunAndConstructShareNode(
"Insert %s with %s -> %s.", op_type, src_var_name, dst_var_name);

RunAndConstructOpFuncNode(
op, src_var_name, dst_var_name, op_func_nodes, skip_run);
op, src_var_name, dst_var_name, op_func_nodes, static_build);
}

void DataTranferHelper::RunAndConstructOpFuncNode(
const std::shared_ptr<OperatorBase>& op,
const std::string& var_name,
const std::string& new_var_name,
std::vector<OpFuncNode>* new_op_func_nodes,
bool skip_run) {
bool static_build) {
auto& op_type = op->Type();

// 1. Construct RuntimeContext
RuntimeContext runtime_context({}, {});
runtime_context.inputs["X"] = {scope_->FindVar(var_name)};
runtime_context.outputs["Out"] = {scope_->Var(new_var_name)};
RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
op.get()->Info().infer_shape_(&infer_shape_ctx);

if (!static_build) {
RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
op->Info().infer_shape_(&infer_shape_ctx);
}

// 2. choose kernel

Expand Down Expand Up @@ -203,8 +207,9 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
} else {
new_op_func_node.phi_kernel_ = op_with_kernel->PhiKernel();

if (skip_run) {
if (static_build) {
FakeInitializeOutputsForFunctionKernel(
*op,
*(new_op_func_node.phi_kernel_),
*(op_with_kernel->PhiKernelSignature()),
runtime_context,
Expand Down Expand Up @@ -449,7 +454,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
OpFuncNode* op_func_node,
std::vector<OpFuncNode>* new_op_func_nodes,
bool use_local_scope,
bool skip_run) {
bool static_build) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();

Expand Down Expand Up @@ -546,7 +551,11 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
op_base->Type() == "fetch_v2");
if (op) {
data_transfer_helper.RunAndConstructOpFuncNode(
op, var_name, new_var_name, new_op_func_nodes, skip_run);
op,
var_name,
new_var_name,
new_op_func_nodes,
static_build);
}
is_transferred = true;
} else {
Expand Down Expand Up @@ -611,7 +620,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
new_op_func_nodes,
use_local_scope,
op_base->Type() == "fetch_v2",
skip_run);
static_build);
}

if (is_transferred) {
Expand Down Expand Up @@ -741,7 +750,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
VariableScope* var_scope,
std::vector<OpFuncNode>* op_func_nodes,
framework::Scope* local_scope,
bool skip_run) {
bool static_build) {
DataTranferHelper data_transfer_helper(place, var_scope, local_scope);
for (auto& var_name_item : out_names) {
std::vector<Variable*>& vars = out_vars->at(var_name_item.first);
Expand Down Expand Up @@ -817,9 +826,9 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
auto op = TransferDtype(
var_name, &new_var_name, src_type, dst_type, var_scope, local_scope);
data_transfer_helper.RunAndConstructOpFuncNode(
op, var_name, new_var_name, op_func_nodes, skip_run);
op, var_name, new_var_name, op_func_nodes, static_build);
data_transfer_helper.RunAndConstructShareNode(
new_var_name, var_name, op_func_nodes, skip_run);
new_var_name, var_name, op_func_nodes, static_build);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ const std::string StringizeDownstreamMap(

const std::map<size_t, std::set<size_t>>& DependencyBuilder::Build(
const std::vector<Instruction>& instructions) {
PADDLE_ENFORCE_EQ(
is_build_,
false,
phi::errors::AlreadyExists("The op dependency has been built"));
if (is_build_) {
return op_downstream_map_;
}

instructions_ = &instructions;
op_num_ = instructions_->size();
Expand Down
Loading