From 5d536b29d7ce437aad79ff046788735a33fb1dbe Mon Sep 17 00:00:00 2001 From: axmat Date: Sun, 22 Aug 2021 16:51:41 +0200 Subject: [PATCH] Add support for operators with multiple outputs --- tmva/sofie/src/RModel.cxx | 53 ++++++++++++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/tmva/sofie/src/RModel.cxx b/tmva/sofie/src/RModel.cxx index 54d0cfe715e65..e65e9b773b776 100644 --- a/tmva/sofie/src/RModel.cxx +++ b/tmva/sofie/src/RModel.cxx @@ -15,6 +15,7 @@ namespace SOFIE{ fOutputTensorNames = other.fOutputTensorNames; fOperators = std::move(other.fOperators); fInitializedTensors = std::move(other.fInitializedTensors); + fIntermediateTensorInfos = std::move(other.fIntermediateTensorInfos); fName = other.fName; fFileName = other.fFileName; fParseTime = other.fParseTime; @@ -29,6 +30,7 @@ namespace SOFIE{ fOutputTensorNames = other.fOutputTensorNames; fOperators = std::move(other.fOperators); fInitializedTensors = std::move(other.fInitializedTensors); + fIntermediateTensorInfos = std::move(other.fIntermediateTensorInfos); fName = other.fName; fFileName = other.fFileName; fParseTime = other.fParseTime; @@ -218,7 +220,8 @@ namespace SOFIE{ } } - if (fOutputTensorNames.size() == 1){ + size_t outputSize = fOutputTensorNames.size(); + if (outputSize == 1) { auto f = fIntermediateTensorInfos.find(fOutputTensorNames[0]); if (f == fIntermediateTensorInfos.end()){ throw std::runtime_error("TMVA-SOFIE: output tensor " + fOutputTensorNames[0] + " not found when trying to get its info"); @@ -227,9 +230,26 @@ namespace SOFIE{ fGC += "std::vector "; } } - }else{ - std::cout << fOutputTensorNames.size() << std::endl; - throw std::runtime_error("TMVA-SOFIE: More than 1 output tensor is not yet supported"); + } else { + std::vector outputTensorsTypes(outputSize); + for (size_t i = 0; i < outputSize; i++) { + auto f = fIntermediateTensorInfos.find(fOutputTensorNames[i]); + if (f == fIntermediateTensorInfos.end()) { + throw std::runtime_error("TMVA-SOFIE: output tensor " + fOutputTensorNames[i] + + " not found when trying to get its info"); + } else { + outputTensorsTypes[i] = f->second.type; + } + } + ETensorType outputType = outputTensorsTypes[0]; + for (size_t i = 0; i < outputSize; i++) { + if (outputTensorsTypes[i] != outputType) { + throw std::runtime_error("TMVA-SOFIE: output tensor " + fOutputTensorNames[i] + " is of different type."); + } + } + if (outputType == ETensorType::FLOAT) { + fGC += "std::vector> "; + } } fGC += "infer("; @@ -248,11 +268,32 @@ namespace SOFIE{ for (size_t id = 0; id < fOperators.size() ; id++){ fGC+= (fOperators[id]->Generate(std::to_string(id))); } - if (fOutputTensorNames.size() == 1){ + if (outputSize == 1) { fGC += "\tstd::vector ret (tensor_" + fOutputTensorNames[0] + ", tensor_" + fOutputTensorNames[0] + " + sizeof(tensor_" + fOutputTensorNames[0] + ") / sizeof(tensor_" + fOutputTensorNames[0] + "[0]));\n"; - fGC += "\treturn ret;\n"; + } else { + for (size_t i = 0; i < outputSize; i++) { + if (!fOutputTensorNames[i].empty()) { + fGC += "\tstd::vector ret_"; + fGC += std::to_string(i); + fGC += " (tensor_" + fOutputTensorNames[i] + ", tensor_" + fOutputTensorNames[i] + " + sizeof(tensor_" + fOutputTensorNames[i] + ") / sizeof(tensor_" + fOutputTensorNames[i] + "[0]));\n"; + } + } + fGC += "\tstd::vector> ret({"; + for (size_t i = 0; i < outputSize; i++) { + if (fOutputTensorNames[i].empty()) { + fGC += "{}"; + } else { + fGC += "ret_"; + fGC += std::to_string(i); + } + if (i < outputSize - 1) { + fGC += ","; + } + } + fGC += "});\n"; } + fGC += "\treturn ret;\n"; fGC += "}\n"; fGC += ("} //TMVA_SOFIE_" + fName + "\n"); }