Skip to content

Commit

Permalink
Add support for operators with multiple outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
axmat authored and lmoneta committed Aug 24, 2021
1 parent 59e754a commit 5d536b2
Showing 1 changed file with 47 additions and 6 deletions.
53 changes: 47 additions & 6 deletions tmva/sofie/src/RModel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace SOFIE{
fOutputTensorNames = other.fOutputTensorNames;
fOperators = std::move(other.fOperators);
fInitializedTensors = std::move(other.fInitializedTensors);
fIntermediateTensorInfos = std::move(other.fIntermediateTensorInfos);
fName = other.fName;
fFileName = other.fFileName;
fParseTime = other.fParseTime;
Expand All @@ -29,6 +30,7 @@ namespace SOFIE{
fOutputTensorNames = other.fOutputTensorNames;
fOperators = std::move(other.fOperators);
fInitializedTensors = std::move(other.fInitializedTensors);
fIntermediateTensorInfos = std::move(other.fIntermediateTensorInfos);
fName = other.fName;
fFileName = other.fFileName;
fParseTime = other.fParseTime;
Expand Down Expand Up @@ -218,7 +220,8 @@ namespace SOFIE{
}
}

if (fOutputTensorNames.size() == 1){
size_t outputSize = fOutputTensorNames.size();
if (outputSize == 1) {
auto f = fIntermediateTensorInfos.find(fOutputTensorNames[0]);
if (f == fIntermediateTensorInfos.end()){
throw std::runtime_error("TMVA-SOFIE: output tensor " + fOutputTensorNames[0] + " not found when trying to get its info");
Expand All @@ -227,9 +230,26 @@ namespace SOFIE{
fGC += "std::vector<float> ";
}
}
}else{
std::cout << fOutputTensorNames.size() << std::endl;
throw std::runtime_error("TMVA-SOFIE: More than 1 output tensor is not yet supported");
} else {
std::vector<ETensorType> outputTensorsTypes(outputSize);
for (size_t i = 0; i < outputSize; i++) {
auto f = fIntermediateTensorInfos.find(fOutputTensorNames[i]);
if (f == fIntermediateTensorInfos.end()) {
throw std::runtime_error("TMVA-SOFIE: output tensor " + fOutputTensorNames[i]
+ " not found when trying to get its info");
} else {
outputTensorsTypes[i] = f->second.type;
}
}
ETensorType outputType = outputTensorsTypes[0];
for (size_t i = 0; i < outputSize; i++) {
if (outputTensorsTypes[i] != outputType) {
throw std::runtime_error("TMVA-SOFIE: output tensor " + fOutputTensorNames[i] + " is of different type.");
}
}
if (outputType == ETensorType::FLOAT) {
fGC += "std::vector<std::vector<float>> ";
}
}

fGC += "infer(";
Expand All @@ -248,11 +268,32 @@ namespace SOFIE{
for (size_t id = 0; id < fOperators.size() ; id++){
fGC+= (fOperators[id]->Generate(std::to_string(id)));
}
if (fOutputTensorNames.size() == 1){
if (outputSize == 1) {
fGC += "\tstd::vector<float> ret (tensor_" + fOutputTensorNames[0] + ", tensor_" + fOutputTensorNames[0] + " + sizeof(tensor_" +
fOutputTensorNames[0] + ") / sizeof(tensor_" + fOutputTensorNames[0] + "[0]));\n";
fGC += "\treturn ret;\n";
} else {
for (size_t i = 0; i < outputSize; i++) {
if (!fOutputTensorNames[i].empty()) {
fGC += "\tstd::vector<float> ret_";
fGC += std::to_string(i);
fGC += " (tensor_" + fOutputTensorNames[i] + ", tensor_" + fOutputTensorNames[i] + " + sizeof(tensor_" + fOutputTensorNames[i] + ") / sizeof(tensor_" + fOutputTensorNames[i] + "[0]));\n";
}
}
fGC += "\tstd::vector<std::vector<float>> ret({";
for (size_t i = 0; i < outputSize; i++) {
if (fOutputTensorNames[i].empty()) {
fGC += "{}";
} else {
fGC += "ret_";
fGC += std::to_string(i);
}
if (i < outputSize - 1) {
fGC += ",";
}
}
fGC += "});\n";
}
fGC += "\treturn ret;\n";
fGC += "}\n";
fGC += ("} //TMVA_SOFIE_" + fName + "\n");
}
Expand Down

0 comments on commit 5d536b2

Please sign in to comment.