Skip to content

Commit

Permalink
refactor(python_ffi): 修改从 exector 存取数据块的接口
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 12, 2024
1 parent a41a09f commit 54c2f7e
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/03runtime/include/runtime/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace refactor::runtime {
decltype(_graph) const &graph() const noexcept { return _graph; }
auto setData(count_t, size_t) -> Arc<hardware::Device::Blob>;
void setData(count_t, Arc<hardware::Device::Blob>);
auto getData(count_t) -> Arc<hardware::Device::Blob> const;
auto getData(count_t) const -> Arc<hardware::Device::Blob>;
void setData(count_t, void const *, size_t);
bool copyData(count_t, void *, size_t) const;
void run();
Expand Down
2 changes: 1 addition & 1 deletion src/03runtime/src/stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace refactor::runtime {
blob->copyFromHost(data, size);
_graph.edges[i].blob = std::move(blob);
}
auto Stream::getData(count_t i) -> Arc<hardware::Device::Blob> const {
auto Stream::getData(count_t i) const -> Arc<hardware::Device::Blob> {
return _graph.edges[i].blob;
}
bool Stream::copyData(count_t i, void *data, size_t size) const {
Expand Down
26 changes: 12 additions & 14 deletions src/09python_ffi/src/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,15 @@ namespace refactor::python_ffi {
_stream.setData(i, data.data(), data.nbytes());
}

auto Executor::getOutput(count_t i) -> pybind11::array {
void Executor::setInputBlob(count_t i, Arc<hardware::Device::Blob> blob) {
i = _stream.graph().topology.globalInputs().at(i);

auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
ASSERT(tensor.bytesSize() == blob->size(), "input size mismatch");
_stream.setData(i, std::move(blob));
}

auto Executor::getOutput(count_t i) const -> pybind11::array {
i = _stream.graph().topology.globalOutputs().at(i);

auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
Expand All @@ -49,20 +57,10 @@ namespace refactor::python_ffi {
return ans;
}

auto Executor::pin(count_t i) -> Arc<hardware::Device::Blob> {
i = _stream.graph().topology.globalInputs().at(i);

if (auto pinned = _stream.getData(i); pinned) {
return pinned;
} else {
auto const &tensor = *_graph.internal().contiguous().edges[i].tensor;
return _stream.setData(i, tensor.bytesSize());
}
}
void Executor::setPinned(count_t i, Arc<hardware::Device::Blob> pinned) {
i = _stream.graph().topology.globalInputs().at(i);
auto Executor::getOutputBlob(count_t i) const -> Arc<hardware::Device::Blob> {
i = _stream.graph().topology.globalOutputs().at(i);

_stream.setData(i, std::move(pinned));
return _stream.getData(i);
}

void Executor::run() {
Expand Down
6 changes: 3 additions & 3 deletions src/09python_ffi/src/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ namespace refactor::python_ffi {
Executor(computation::Graph, runtime::Stream);
void dispatch(Arc<hardware::Device>, std::string allocator);
void setInput(count_t, pybind11::array);
auto getOutput(count_t) -> pybind11::array;
auto pin(count_t) -> Arc<hardware::Device::Blob>;
void setPinned(count_t, Arc<hardware::Device::Blob>);
void setInputBlob(count_t, Arc<hardware::Device::Blob>);
auto getOutput(count_t) const -> pybind11::array;
auto getOutputBlob(count_t) const -> Arc<hardware::Device::Blob>;
void run();
void bench(bool sync);
void trace(std::string path, std::string format);
Expand Down
4 changes: 2 additions & 2 deletions src/09python_ffi/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ namespace refactor::python_ffi {
py::class_<Executor , Arc<Executor>>(m, "Executor" )
.def("dispatch" , &Executor::dispatch , return_::automatic )
.def("set_input" , &Executor::setInput , return_::automatic )
.def("set_input_blob" , &Executor::setInputBlob , return_::automatic )
.def("get_output" , &Executor::getOutput , return_::move )
.def("pin" , &Executor::pin , return_::move )
.def("set_pinned" , &Executor::setPinned , return_::automatic )
.def("get_output_blob" , &Executor::getOutputBlob , return_::move )
.def("run" , &Executor::run , return_::automatic )
.def("bench" , &Executor::bench , return_::automatic )
.def("trace" , &Executor::trace , return_::automatic )
Expand Down

0 comments on commit 54c2f7e

Please sign in to comment.