Add scalar support in ORT backend
Tabrizian committed Sep 21, 2023
1 parent 4b88138 commit a855fa5
Showing 1 changed file: src/onnxruntime.cc (108 additions, 20 deletions)
@@ -885,7 +885,9 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
    triton::common::TritonJson::Value reshape_dims(
        ModelConfig(), triton::common::TritonJson::ValueType::ARRAY);
    RETURN_IF_ERROR(reshape.Add("shape", std::move(reshape_dims)));
-   RETURN_IF_ERROR(io.Add("reshape", std::move(reshape)));
+   if (MaxBatchSize() > 0) {
+     RETURN_IF_ERROR(io.Add("reshape", std::move(reshape)));
+   }
  }
  RETURN_IF_ERROR(io.Add("dims", std::move(dims)));
  RETURN_IF_ERROR(ios.Append(std::move(io)));
@@ -998,6 +1000,12 @@ class ModelInstanceState : public BackendModelInstance {
  // map of output name -> tensor info
  OnnxTensorInfoMap output_tensor_infos_;

+  // map of input name -> tensor info
+  OnnxTensorInfoMap input_tensor_infos_;
+
+  // A map from scalar output tensors to the dimensions specified in model config
+  std::unordered_map<std::string, std::vector<int64_t>> scalar_outputs_;
+
  // Onnx Runtime variables that will be reset and used for every run
  // on this instance.
  std::vector<OrtValue*> input_tensors_;
@@ -1313,9 +1321,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
{
  std::set<std::string> input_tensor_names;
  RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
-
-  OnnxTensorInfoMap input_tensor_infos;
-  RETURN_IF_ERROR(InputInfos(session_, default_allocator_, input_tensor_infos));
+  RETURN_IF_ERROR(
+      InputInfos(session_, default_allocator_, input_tensor_infos_));

  std::set<std::string> overridable_initializer_tensor_names;
  RETURN_IF_ERROR(OverridableInitializerNames(
@@ -1325,12 +1332,13 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
  RETURN_IF_ERROR(OverridableInitializerInfos(
      session_, default_allocator_, overridable_initializer_tensor_infos));

-  if (input_tensor_infos.size() != expected_input_cnt) {
+  if (input_tensor_infos_.size() != expected_input_cnt) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG,
        (std::string("unable to load model '") + model_state_->Name() +
         "', configuration expects " + std::to_string(expected_input_cnt) +
-         " inputs, model provides " + std::to_string(input_tensor_infos.size()))
+         " inputs, model provides " +
+         std::to_string(input_tensor_infos_.size()))
            .c_str());
  }

@@ -1357,8 +1365,9 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)

    const auto& tensor_names =
        io_optional ? overridable_initializer_tensor_names : input_tensor_names;
-    const auto& tensor_infos =
-        io_optional ? overridable_initializer_tensor_infos : input_tensor_infos;
+    const auto& tensor_infos = io_optional
+                                   ? overridable_initializer_tensor_infos
+                                   : input_tensor_infos_;
    auto iit = tensor_infos.find(io_name);
    if (iit == tensor_infos.end()) {
      RETURN_IF_ERROR(CheckAllowedModelInput(io, tensor_names));
@@ -1419,9 +1428,30 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
              .c_str());
      }
    } else {
-      RETURN_IF_ERROR(CompareDimsSupported(
-          model_state_->Name(), io_name, iit->second.dims_, dims,
-          model_state_->MaxBatchSize(), false /* compare_exact */));
+      if (model_state_->MaxBatchSize() != 0 || iit->second.dims_.size() > 0) {
+        RETURN_IF_ERROR(CompareDimsSupported(
+            model_state_->Name(), io_name, iit->second.dims_, dims,
+            model_state_->MaxBatchSize(), false /* compare_exact */));
+      } else {
+        // If max_batch_size == 0 and the tensor is a scalar, all of the
+        // dimensions specified must be equal to 1.
+        for (auto& dim : dims) {
+          if (dim != 1) {
+            return TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INVALID_ARG,
+                (std::string("unable to load model '") + model_state_->Name() +
+                 "', scalar tensor '" + io_name +
+                 "' should only provide 1 in the model configuration when the "
+                 "model doesn't support batching. Model configuration "
+                 "provided: " +
+                 ShapeToString(dims) + ".")
+                    .c_str());
+          }
+        }
+
+        // Store the dimensions for reference.
+        scalar_inputs_[io_name] = dims;
+      }
    }
  }

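The new branch encodes a simple rule: when max_batch_size is 0 and the ONNX model reports a rank-0 (scalar) tensor, the model configuration may only describe that tensor with dims that are all 1. A minimal standalone sketch of the same rule; the function name and types here are illustrative, not part of the backend:

#include <cstdint>
#include <iostream>
#include <vector>

// Returns true when the config dims are acceptable: a scalar (rank-0)
// model tensor in a non-batching model may only be described as all 1s
// (typically just [1]); every other case falls through to the backend's
// regular dims comparison.
bool ScalarConfigDimsValid(
    int max_batch_size, const std::vector<int64_t>& model_dims,
    const std::vector<int64_t>& config_dims)
{
  if (max_batch_size != 0 || !model_dims.empty()) {
    return true;  // not the scalar special case
  }
  for (int64_t dim : config_dims) {
    if (dim != 1) {
      return false;
    }
  }
  return true;
}

int main()
{
  // Scalar model tensor, non-batching model: [1] is accepted, [3] is not.
  std::cout << ScalarConfigDimsValid(0, {}, {1}) << "\n";  // prints 1
  std::cout << ScalarConfigDimsValid(0, {}, {3}) << "\n";  // prints 0
}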
@@ -1482,9 +1512,30 @@ ModelInstanceState::ValidateOutputs()

    // The batch output shape doesn't necessarily match the model
    if (model_state_->FindBatchOutput(io_name) == nullptr) {
-      RETURN_IF_ERROR(CompareDimsSupported(
-          model_state_->Name(), io_name, iit->second.dims_, dims,
-          model_state_->MaxBatchSize(), true /* compare_exact */));
+      // If max_batch_size == 0 and the tensor is a scalar, all of the
+      // dimensions specified must be equal to 1.
+      if (model_state_->MaxBatchSize() > 0 || iit->second.dims_.size() > 0) {
+        RETURN_IF_ERROR(CompareDimsSupported(
+            model_state_->Name(), io_name, iit->second.dims_, dims,
+            model_state_->MaxBatchSize(), true /* compare_exact */));
+      } else {
+        for (auto& dim : dims) {
+          if (dim != 1) {
+            return TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INVALID_ARG,
+                (std::string("unable to load model '") + model_state_->Name() +
+                 "', scalar tensor '" + io_name +
+                 "' should only provide 1 in the model configuration when the "
+                 "model doesn't support batching. Model configuration "
+                 "provided: " +
+                 ShapeToString(dims) + ".")
+                    .c_str());
+          }
+        }
+
+        // Store the dimensions for reference.
+        scalar_outputs_[io_name] = dims;
+      }
    }
  }

@@ -1900,13 +1951,34 @@ ModelInstanceState::SetInputTensors(
          input_name, nullptr, 0, allowed_input_types, &input_buffer,
          &batchn_byte_size, &memory_type, &memory_type_id));

+      auto iti = input_tensor_infos_.find(input_name);
+      if (iti == input_tensor_infos_.end()) {
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_INTERNAL,
+            std::string(
+                std::string(
+                    "Failed to retrieve the ONNX input tensor info for '") +
+                input_name + "'.")
+                .c_str());
+      }
+
      // Create ORT Tensor
-      RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
-          memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
-                                                 : cpu_allocator_info_,
-          (void*)input_buffer, batchn_byte_size, batchn_shape.data(),
-          batchn_shape.size(), ConvertToOnnxDataType(input_datatype),
-          &input_tensors_.back()));
+      if (iti->second.dims_.size() == 0) {
+        // Scalar tensor
+        RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
+            memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
+                                                   : cpu_allocator_info_,
+            (void*)input_buffer, batchn_byte_size, nullptr /* scalar */,
+            0 /* number of dims */, ConvertToOnnxDataType(input_datatype),
+            &input_tensors_.back()));
+      } else {
+        RETURN_IF_ORT_ERROR(ort_api->CreateTensorWithDataAsOrtValue(
+            memory_type == TRITONSERVER_MEMORY_GPU ? cuda_allocator_info_
+                                                   : cpu_allocator_info_,
+            (void*)input_buffer, batchn_byte_size, batchn_shape.data(),
+            batchn_shape.size(), ConvertToOnnxDataType(input_datatype),
+            &input_tensors_.back()));
+      }
      RETURN_IF_ORT_ERROR(
          ort_api->BindInput(io_binding_, input_name, input_tensors_.back()));
    } else {
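The ORT detail this relies on: a scalar is a rank-0 tensor, so CreateTensorWithDataAsOrtValue receives a null shape pointer and a dimension count of 0 while the data buffer still carries exactly one element. A self-contained sketch against the ONNX Runtime C API (CPU memory only, error handling reduced to asserts for brevity):

#include <onnxruntime_c_api.h>

#include <cassert>
#include <cstdio>

int main()
{
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtMemoryInfo* mem_info = nullptr;
  assert(
      ort->CreateCpuMemoryInfo(
          OrtArenaAllocator, OrtMemTypeDefault, &mem_info) == nullptr);

  // One float of backing storage; for a scalar the shape pointer is null
  // and the dimension count is zero.
  float value = 42.0f;
  OrtValue* scalar = nullptr;
  assert(
      ort->CreateTensorWithDataAsOrtValue(
          mem_info, &value, sizeof(value), /* shape */ nullptr,
          /* shape_len */ 0, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
          &scalar) == nullptr);

  // ORT reports rank 0 for the resulting tensor.
  OrtTensorTypeAndShapeInfo* shape_info = nullptr;
  assert(ort->GetTensorTypeAndShape(scalar, &shape_info) == nullptr);
  size_t rank = 0;
  assert(ort->GetDimensionsCount(shape_info, &rank) == nullptr);
  printf("rank = %zu\n", rank);  // prints: rank = 0

  ort->ReleaseTensorTypeAndShapeInfo(shape_info);
  ort->ReleaseValue(scalar);
  ort->ReleaseMemoryInfo(mem_info);
  return 0;
}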
@@ -2283,6 +2355,22 @@ ModelInstanceState::ReadOutputTensors(
            batchn_shape, dtype, output_tensor, &output_buffer, string_buffers,
            offsets));

+        // If the number of dimensions is equal to zero, it means that it is
+        // a scalar and the dimensions specified in the model configuration
+        // will be used.
+        if (batchn_shape.size() == 0) {
+          auto scalar_output_dims_it = scalar_outputs_.find(name);
+          if (scalar_output_dims_it == scalar_outputs_.end()) {
+            return TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INTERNAL,
+                std::string(
+                    "Failed to find the scalar output dimension for " + name +
+                    " in the model configuration.")
+                    .c_str());
+          }
+          batchn_shape = scalar_output_dims_it->second;
+        }
+
        if (output_tensor_pair.first != -1) {
          if (dtype == TRITONSERVER_TYPE_BYTES) {
            auto content = string_buffers.back().data();
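On the output side the substitution is the mirror image: when ORT reports a rank-0 shape, the backend answers with the dims recorded from the model configuration so Triton sees a concrete shape. A toy sketch of that lookup, with illustrative names:

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
  // Dims recorded at load time for scalar outputs (cf. scalar_outputs_).
  std::unordered_map<std::string, std::vector<int64_t>> scalar_outputs = {
      {"probability", {1}}};

  // Shape reported by ORT for this output: empty, i.e. a scalar.
  std::vector<int64_t> batchn_shape;

  if (batchn_shape.empty()) {
    auto it = scalar_outputs.find("probability");
    if (it != scalar_outputs.end()) {
      batchn_shape = it->second;  // report [1] to Triton instead of rank 0
    }
  }

  std::cout << "reported rank: " << batchn_shape.size() << "\n";  // prints 1
}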