diff --git a/runtime/onert/backend/cpu/BackendContext.cc b/runtime/onert/backend/cpu/BackendContext.cc index 95635152a9a..cf199a7fe32 100644 --- a/runtime/onert/backend/cpu/BackendContext.cc +++ b/runtime/onert/backend/cpu/BackendContext.cc @@ -23,6 +23,7 @@ #include "ir/OperandIndexMap.h" #include "ir/OperandIndexSequence.h" #include "backend/basic/BackendContextHelpers.h" +#include "backend/basic/TensorRegistry.h" namespace onert { @@ -44,6 +45,12 @@ FunctionMap BackendContext::genKernels() basic::initConsts(graph()->operands(), external_operands(), tensor_registry.get(), tensor_builder->getSharedMemoryOperandIndexes()); + // TODO: Change type of tensor_registry field to TensorRegistry + auto tensor_registry_concreted = dynamic_cast(tensor_registry.get()); + assert(tensor_registry_concreted); + basic::initSharedMemoryConsts(graph()->operands(), external_operands(), tensor_registry_concreted, + tensor_builder->getSharedMemoryOperandIndexes()); + for (auto &&op_ind : _data.op_order) { auto fn_seq = kernel_gen->generate(op_ind); diff --git a/runtime/onert/core/include/backend/basic/BackendContextHelpers.h b/runtime/onert/core/include/backend/basic/BackendContextHelpers.h index c95dd9b7277..4fec186fd8c 100644 --- a/runtime/onert/core/include/backend/basic/BackendContextHelpers.h +++ b/runtime/onert/core/include/backend/basic/BackendContextHelpers.h @@ -24,6 +24,7 @@ #include "util/logging.h" #include "backend/ITensorRegistry.h" #include "backend/BackendContext.h" +#include "backend/basic/TensorRegistry.h" #include "Tensor.h" namespace onert @@ -240,45 +241,58 @@ template ITensorRegistry *genTensors(T_BackendContex ctx.data().op_order, {}); } +inline void +initSharedMemoryConsts(const ir::Operands &operands, + const util::Set &external_operands, + TensorRegistry *tensor_registry, + const ir::OperandIndexMap &shared_memory_operands_map) +{ + operands.iterate([&](const ir::OperandIndex &ind, const ir::Operand &) { + if (external_operands.contains(ind)) + return; + const auto shared_mem_it = shared_memory_operands_map.find(ind); + if (shared_mem_it == std::end(shared_memory_operands_map)) + return; // no shared memory source + if (!operands.at(shared_mem_it->second).isConstant()) + return; // source operand not a constant + + VERBOSE(FillOperandData) << "Fill shared data for " << ind << std::endl; + + const auto &source_operand_ind = operands.at(shared_mem_it->second); + auto memory_source_data = source_operand_ind.shareData(); + assert(memory_source_data && memory_source_data->base()); + auto tensor = tensor_registry->getNativeTensor(ind); + assert(tensor != nullptr); + tensor->setBuffer(const_cast(memory_source_data->base())); + }); +} + inline void initConsts(const ir::Operands &operands, const util::Set &external_operands, ITensorRegistry *tensor_registry, const ir::OperandIndexMap &shared_memory_operands_map) { operands.iterate([&](const ir::OperandIndex &ind, const ir::Operand &operand) { - const bool has_const_shared_memory = + if (external_operands.contains(ind) || !operand.isConstant()) + return; + const bool has_const_shared_source = shared_memory_operands_map.find(ind) != std::end(shared_memory_operands_map) && operands.at(shared_memory_operands_map.at(ind)).isConstant(); - if (external_operands.contains(ind)) - return; - const bool can_be_initialized_as_const = operand.isConstant() || has_const_shared_memory; - if (!can_be_initialized_as_const) - // tensor currently processed not a const and source memory tensor (if exists) not a const too - return; + if (has_const_shared_source) + return; // tensors with shared memory are processed in initSharedMemoryConsts auto tensor = tensor_registry->getNativeITensor(ind); assert(tensor != nullptr); VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl; - if (has_const_shared_memory) - { - const auto &source_operand_ind = operands.at(shared_memory_operands_map.at(ind)); - auto memory_source_data = source_operand_ind.shareData(); - assert(memory_source_data && memory_source_data->base()); - auto shared_mem_tensor = dynamic_cast(tensor); - assert(shared_mem_tensor != nullptr); - shared_mem_tensor->setBuffer(const_cast(memory_source_data->base())); - return; - } - // the default flow for constant initialization auto data = operand.shareData(); assert(data && data->base()); - auto ext_tensor = dynamic_cast(tensor); + ExternalTensor *ext_tensor = dynamic_cast(tensor); + if (ext_tensor == nullptr) - { throw std::runtime_error{"This tensor is not external tensor"}; - } + ext_tensor->setData(data); }); }