executor.cpp
#include "executor.h"

namespace torch {
namespace executor {

Executor::Executor(const executorch::Program* program)
    : program_(program), plan_(program) {}

// Initializes the runtime representation of the execution plan stored at
// `index` in the program's list of serialized execution plans.
int Executor::init_execution_plan(int index) {
  auto serialization_plan = program_->execution_plan()->GetMutableObject(index);
  return plan_.init(serialization_plan);
}
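
// Rough sketch of the Value tagged union this file relies on. The real
// definition lives in executor.h; the field types below are assumptions
// inferred from how values_ is used in init():
//
//   struct Value {
//     Tag tag;              // discriminates the payload (Tag::Int, Tag::Tensor, ...)
//     union {
//       int64_t as_int;     // scalar integer payload
//       Tensor* as_tensor;  // heap-allocated tensor payload
//     } payload;
//   };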
int ExecutionPlan::init(executorch::ExecutionPlan* s_plan) {
  serialization_plan_ = s_plan;

  // Load values: materialize each serialized value into the runtime
  // tagged-union representation.
  n_value_ = serialization_plan_->values()->size();
  values_ = new Value[n_value_];
  for (int i = 0; i < n_value_; ++i) {
    auto serialization_value = serialization_plan_->values()->Get(i);
    switch (serialization_value->val_type()) {
      case executorch::ValueUnion::Int: {
        values_[i].tag = Tag::Int;
        values_[i].payload.as_int = serialization_value->val_as_Int()->int_val();
      } break;
      case executorch::ValueUnion::Tensor: {
        values_[i].tag = Tag::Tensor;
        auto s_tensor = serialization_value->val_as_Tensor();
        // TODO: use placement new
        Tensor* t = new Tensor(
            static_cast<ScalarType>(s_tensor->scalar_type()),
            s_tensor->sizes()->size(),
            const_cast<int*>(s_tensor->sizes()->data()));
        if (s_tensor->buffer_index() > 0) { // buffer index 0 is reserved for RW data
          // Constant data: point directly at the buffer embedded in the program.
          auto buffer =
              program_->buffers()->GetMutableObject(s_tensor->buffer_index());
          t->data = static_cast<void*>(buffer->mutable_data()->data());
        } else { // TODO: init RW memory pools and do pointer mapping
          t->data = new uint8_t[t->nbytes];
        }
        values_[i].payload.as_tensor = t;
      } break;
      default: // TODO: support all types
        error_with_message("type not supported");
    }
  }
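
  // Note: values_ and the tensor storage allocated above are owned by the
  // plan and never freed in this prototype; a later revision would need
  // matching cleanup (or placement new into a managed pool, per the TODOs).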
  // Resolve operators: look up each operator function by name once at init
  // time, so execution does not repeat the string lookup per call.
  n_operator = serialization_plan_->operators()->size();
  operators_ = new OpFunction[n_operator];
  for (int i = 0; i < n_operator; ++i) {
    std::string op_name(serialization_plan_->operators()->Get(i)->name()->str());
    operators_[i] = getOpsFn(op_name);
  }
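
  // OpFunction's shape is assumed from its use in execute() below: a callable
  // taking the kernel's resolved argument array, roughly
  //
  //   using OpFunction = void (*)(Value* args);
  //
  // with getOpsFn() acting as a registry lookup from operator name to that
  // callable.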
  // Load chains: build the runtime Chain/Kernel structures that execute()
  // walks, resolving each kernel's argument indices into Value copies.
  auto chains = serialization_plan_->chains();
  n_chains_ = chains->size();
  chains_ = new Chain[n_chains_];
  for (int i = 0; i < n_chains_; ++i) {
    auto kernels = chains->Get(i)->kernels();
    Chain* r_chain = &chains_[i]; // runtime chain
    r_chain->n_kernels_ = kernels->size();
    r_chain->kernels_ = new Kernel[r_chain->n_kernels_];
    for (int j = 0; j < r_chain->n_kernels_; ++j) {
      auto kernel = kernels->Get(j); // serialization kernel
      Kernel* r_kernel = &r_chain->kernels_[j];
      r_kernel->op_index_ = kernel->op_index();
      auto args = kernel->args();
      r_kernel->n_args_ = args->size();
      r_kernel->args_ = new Value[r_kernel->n_args_];
      for (int k = 0; k < r_kernel->n_args_; ++k) {
        // args->Get(k) is an index into values_; the Value is copied, but
        // tensor payloads stay shared through the Tensor pointer inside it.
        r_kernel->args_[k] = values_[args->Get(k)];
      }
    }
  }
  return 0;
}
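
// Sketch of the runtime structures populated above. These shapes are
// assumptions inferred from usage in this file; the real definitions live
// in executor.h:
//
//   struct Kernel {
//     int op_index_;   // index into operators_
//     int n_args_;
//     Value* args_;    // resolved argument values
//   };
//
//   struct Chain {
//     int n_kernels_;
//     Kernel* kernels_;
//   };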

int ExecutionPlan::execute() const {
  // V0: execute chains sequentially.
  // TODO: execute them in patterns based on (possible) control flow,
  // delegation, or async execution.
  // chain loop
  for (int i = 0; i < n_chains_; ++i) {
    Chain* chain = &chains_[i];
    // kernel loop: dispatch each kernel to its pre-resolved operator.
    for (int j = 0; j < chain->n_kernels_; ++j) {
      Kernel* kernel = &chain->kernels_[j];
      operators_[kernel->op_index_](kernel->args_);
    }
  }
  return 0;
}
} // namespace executor
} // namespace torch
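
// Minimal usage sketch (assumptions: the program has already been
// deserialized from its flatbuffer, and the caller has a way to run the
// initialized plan; neither is shown in this file):
//
//   const executorch::Program* program = /* load flatbuffer */;
//   torch::executor::Executor executor(program);
//   executor.init_execution_plan(0); // pick the first serialized plan
//   // ... then run the initialized plan, e.g. via ExecutionPlan::execute()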