Skip to content

Commit

Permalink
init consts refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mbencer committed Dec 11, 2024
1 parent 1cc0ad9 commit 577a67f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 21 deletions.
7 changes: 7 additions & 0 deletions runtime/onert/backend/cpu/BackendContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "ir/OperandIndexMap.h"
#include "ir/OperandIndexSequence.h"
#include "backend/basic/BackendContextHelpers.h"
#include "backend/basic/TensorRegistry.h"

namespace onert
{
Expand All @@ -44,6 +45,12 @@ FunctionMap BackendContext::genKernels()
basic::initConsts(graph()->operands(), external_operands(), tensor_registry.get(),
tensor_builder->getSharedMemoryOperandIndexes());

// TODO: Change type of tensor_registry field to TensorRegistry
auto tensor_registry_concreted = dynamic_cast<basic::TensorRegistry *>(tensor_registry.get());
assert(tensor_registry_concreted);
basic::initSharedMemoryConsts(graph()->operands(), external_operands(), tensor_registry_concreted,
tensor_builder->getSharedMemoryOperandIndexes());

for (auto &&op_ind : _data.op_order)
{
auto fn_seq = kernel_gen->generate(op_ind);
Expand Down
56 changes: 35 additions & 21 deletions runtime/onert/core/include/backend/basic/BackendContextHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "util/logging.h"
#include "backend/ITensorRegistry.h"
#include "backend/BackendContext.h"
#include "backend/basic/TensorRegistry.h"
#include "Tensor.h"

namespace onert
Expand Down Expand Up @@ -240,45 +241,58 @@ template <typename T_BackendContext> ITensorRegistry *genTensors(T_BackendContex
ctx.data().op_order, {});
}

inline void
initSharedMemoryConsts(const ir::Operands &operands,
const util::Set<ir::OperandIndex> &external_operands,
TensorRegistry *tensor_registry,
const ir::OperandIndexMap<ir::OperandIndex> &shared_memory_operands_map)
{
operands.iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
if (external_operands.contains(ind))
return;
const auto shared_mem_it = shared_memory_operands_map.find(ind);
if (shared_mem_it == std::end(shared_memory_operands_map))
return; // no shared memory source
if (!operands.at(shared_mem_it->second).isConstant())
return; // source operand not a constant

VERBOSE(FillOperandData) << "Fill shared data for " << ind << std::endl;

const auto &source_operand_ind = operands.at(shared_mem_it->second);
auto memory_source_data = source_operand_ind.shareData();
assert(memory_source_data && memory_source_data->base());
auto tensor = tensor_registry->getNativeTensor(ind);
assert(tensor != nullptr);
tensor->setBuffer(const_cast<uint8_t *>(memory_source_data->base()));
});
}

inline void initConsts(const ir::Operands &operands,
const util::Set<ir::OperandIndex> &external_operands,
ITensorRegistry *tensor_registry,
const ir::OperandIndexMap<ir::OperandIndex> &shared_memory_operands_map)
{
operands.iterate([&](const ir::OperandIndex &ind, const ir::Operand &operand) {
const bool has_const_shared_memory =
if (external_operands.contains(ind) || !operand.isConstant())
return;
const bool has_const_shared_source =
shared_memory_operands_map.find(ind) != std::end(shared_memory_operands_map) &&
operands.at(shared_memory_operands_map.at(ind)).isConstant();
if (external_operands.contains(ind))
return;
const bool can_be_initialized_as_const = operand.isConstant() || has_const_shared_memory;
if (!can_be_initialized_as_const)
// tensor currently processed not a const and source memory tensor (if exists) not a const too
return;
if (has_const_shared_source)
return; // tensors with shared memory are processed in initSharedMemoryConsts

auto tensor = tensor_registry->getNativeITensor(ind);
assert(tensor != nullptr);

VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl;

if (has_const_shared_memory)
{
const auto &source_operand_ind = operands.at(shared_memory_operands_map.at(ind));
auto memory_source_data = source_operand_ind.shareData();
assert(memory_source_data && memory_source_data->base());
auto shared_mem_tensor = dynamic_cast<Tensor *>(tensor);
assert(shared_mem_tensor != nullptr);
shared_mem_tensor->setBuffer(const_cast<uint8_t *>(memory_source_data->base()));
return;
}
// the default flow for constant initialization
auto data = operand.shareData();
assert(data && data->base());
auto ext_tensor = dynamic_cast<ExternalTensor *>(tensor);
ExternalTensor *ext_tensor = dynamic_cast<ExternalTensor *>(tensor);

if (ext_tensor == nullptr)
{
throw std::runtime_error{"This tensor is not external tensor"};
}

ext_tensor->setData(data);
});
}
Expand Down

0 comments on commit 577a67f

Please sign in to comment.