From 814fd590e34358924b18c531f9188aab5f67f808 Mon Sep 17 00:00:00 2001 From: gordinmitya Date: Thu, 24 Sep 2020 18:49:15 +0300 Subject: [PATCH] add onnx-runtime Former-commit-id: c09e21e7165676d9deee4e59ffe2aaaf4128670e --- .gitattributes | 2 + app/build.gradle | 1 + .../java/ru/gordinmitya/dnnbenchmark/App.kt | 14 +- onnxruntime/.gitignore | 1 + onnxruntime/build.gradle | 47 + onnxruntime/consumer-rules.pro | 0 onnxruntime/proguard-rules.pro | 21 + onnxruntime/src/main/AndroidManifest.xml | 5 + .../src/main/assets/mobilenet_v2.onnx | 0 onnxruntime/src/main/cpp/CMakeLists.txt | 29 + onnxruntime/src/main/cpp/imageresizer.h | 105 ++ .../onnxruntime/core/common/code_location.h | 57 + .../includes/onnxruntime/core/common/common.h | 245 ++++ .../core/common/const_pointer_container.h | 85 ++ .../core/common/eigen_common_wrapper.h | 41 + .../onnxruntime/core/common/exceptions.h | 71 + .../onnxruntime/core/common/logging/capture.h | 115 ++ .../onnxruntime/core/common/logging/isink.h | 41 + .../onnxruntime/core/common/logging/logging.h | 347 +++++ .../onnxruntime/core/common/logging/macros.h | 209 +++ .../core/common/logging/severity.h | 22 + .../onnxruntime/core/common/make_unique.h | 148 ++ .../onnxruntime/core/common/optional.h | 42 + .../includes/onnxruntime/core/common/status.h | 191 +++ .../onnxruntime/core/framework/alloc_kind.h | 34 + .../onnxruntime/core/framework/allocator.h | 304 +++++ .../core/framework/customregistry.h | 49 + .../onnxruntime/core/framework/data_types.h | 1013 ++++++++++++++ .../core/framework/data_types_internal.h | 501 +++++++ .../onnxruntime/core/framework/endian.h | 27 + .../core/framework/execution_provider.h | 183 +++ .../onnxruntime/core/framework/fence.h | 57 + .../core/framework/framework_common.h | 22 + .../onnxruntime/core/framework/func_api.h | 27 + .../core/framework/kernel_def_builder.h | 284 ++++ .../core/framework/kernel_registry.h | 84 ++ .../onnxruntime/core/framework/ml_value.h | 122 ++ .../onnxruntime/core/framework/op_kernel.h | 407 ++++++ .../core/framework/op_kernel_info.h | 68 + .../core/framework/op_node_proto_helper.h | 140 ++ .../onnxruntime/core/framework/run_options.h | 48 + .../core/framework/sparse_tensor.h | 74 + .../onnxruntime/core/framework/tensor.h | 241 ++++ .../onnxruntime/core/framework/tensor_shape.h | 143 ++ .../onnxruntime/core/graph/basic_types.h | 39 + .../onnxruntime/core/graph/constants.h | 83 ++ .../onnxruntime/core/graph/function.h | 42 + .../includes/onnxruntime/core/graph/graph.h | 1210 +++++++++++++++++ .../onnxruntime/core/graph/graph_nodes.h | 149 ++ .../onnxruntime/core/graph/graph_viewer.h | 140 ++ .../core/graph/indexed_sub_graph.h | 59 + .../onnxruntime/core/graph/node_arg.h | 108 ++ .../onnxruntime/core/graph/onnx_protobuf.h | 42 + .../onnxruntime/core/graph/schema_registry.h | 152 +++ .../core/optimizer/graph_transformer.h | 70 + .../core/optimizer/graph_transformer_level.h | 19 + .../core/optimizer/graph_transformer_utils.h | 39 + .../onnxruntime/core/optimizer/rewrite_rule.h | 89 ++ .../optimizer/rule_based_graph_transformer.h | 82 ++ .../onnxruntime/core/platform/Barrier.h | 69 + .../platform/EigenNonBlockingThreadPool.h | 1008 ++++++++++++++ .../onnxruntime/core/platform/ort_mutex.h | 193 +++ .../onnxruntime/core/platform/threadpool.h | 386 ++++++ .../onnxruntime/core/platform/tracing.h | 9 + .../platform/windows/TraceLoggingConfig.h | 81 ++ .../core/platform/windows/readme.txt | 2 + .../core/providers/acl/acl_provider_factory.h | 18 + .../providers/armnn/armnn_provider_factory.h | 18 + .../core/providers/cpu/cpu_provider_factory.h | 18 + .../providers/cuda/cuda_provider_factory.h | 17 + .../core/providers/dml/dml_provider_factory.h | 47 + .../providers/dnnl/dnnl_provider_factory.h | 17 + .../migraphx/migraphx_provider_factory.h | 15 + .../ngraph/ngraph_provider_factory.h | 17 + .../providers/nnapi/nnapi_provider_factory.h | 15 + .../nuphar/nuphar_provider_factory.h | 17 + .../openvino/openvino_provider_factory.h | 18 + .../onnxruntime/core/providers/providers.h | 13 + .../providers/rknpu/rknpu_provider_factory.h | 14 + .../tensorrt/tensorrt_provider_factory.h | 15 + .../vitisai/vitisai_provider_factory.h | 17 + .../providers/winml/winml_provider_factory.h | 9 + .../core/session/automl_data_containers.h | 30 + .../onnxruntime/core/session/environment.h | 70 + .../experimental_onnxruntime_cxx_api.h | 67 + .../experimental_onnxruntime_cxx_inline.h | 109 ++ .../core/session/onnxruntime_c_api.h | 883 ++++++++++++ .../core/session/onnxruntime_cxx_api.h | 381 ++++++ .../core/session/onnxruntime_cxx_inline.h | 630 +++++++++ onnxruntime/src/main/cpp/logs.h | 26 + onnxruntime/src/main/cpp/onnxjni.cpp | 78 ++ .../src/main/cpp/onnxruntime_inference.cpp | 137 ++ .../src/main/cpp/onnxruntime_inference.h | 56 + onnxruntime/src/main/cpp/postprocess.h | 79 ++ onnxruntime/src/main/cpp/preprocess.h | 69 + onnxruntime/src/main/cpp/utils.h | 32 + .../gordinmitya/onnxruntime/ConvertedModel.kt | 30 + .../gordinmitya/onnxruntime/ONNXClassifier.kt | 39 + .../gordinmitya/onnxruntime/ONNXFramework.kt | 30 + .../onnxruntime/ONNXInfereceType.kt | 7 + .../gordinmitya/onnxruntime/ONNXNative.java | 58 + .../main/jniLibs/arm64-v8a/libonnxruntime.so | 3 + .../jniLibs/armeabi-v7a/libonnxruntime.so | 3 + opencv/src/main/assets/.gitattributes | 1 - settings.gradle | 2 +- 105 files changed, 12835 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/.gitignore create mode 100644 onnxruntime/build.gradle create mode 100644 onnxruntime/consumer-rules.pro create mode 100644 onnxruntime/proguard-rules.pro create mode 100644 onnxruntime/src/main/AndroidManifest.xml rename {opencv => onnxruntime}/src/main/assets/mobilenet_v2.onnx (100%) create mode 100644 onnxruntime/src/main/cpp/CMakeLists.txt create mode 100644 onnxruntime/src/main/cpp/imageresizer.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/code_location.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/common.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/const_pointer_container.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/eigen_common_wrapper.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/exceptions.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/capture.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/isink.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/logging.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/macros.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/severity.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/make_unique.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/optional.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/common/status.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/alloc_kind.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/allocator.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/customregistry.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/data_types.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/data_types_internal.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/endian.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/execution_provider.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/fence.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/framework_common.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/func_api.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/kernel_def_builder.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/kernel_registry.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/ml_value.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/op_kernel.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/op_kernel_info.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/op_node_proto_helper.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/run_options.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/sparse_tensor.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/tensor.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/tensor_shape.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/graph/basic_types.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/graph/constants.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/graph/function.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/graph/graph.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/graph/graph_nodes.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/graph/graph_viewer.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/graph/indexed_sub_graph.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/graph/node_arg.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/graph/onnx_protobuf.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/graph/schema_registry.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/optimizer/graph_transformer.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/optimizer/graph_transformer_level.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/optimizer/graph_transformer_utils.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/optimizer/rewrite_rule.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/optimizer/rule_based_graph_transformer.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/platform/Barrier.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/platform/EigenNonBlockingThreadPool.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/platform/ort_mutex.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/platform/threadpool.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/platform/tracing.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/platform/windows/TraceLoggingConfig.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/platform/windows/readme.txt create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/acl/acl_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/armnn/armnn_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/cpu/cpu_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/cuda/cuda_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/dml/dml_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/dnnl/dnnl_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/ngraph/ngraph_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/nuphar/nuphar_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/openvino/openvino_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/providers.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/rknpu/rknpu_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/vitisai/vitisai_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/providers/winml/winml_provider_factory.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/session/automl_data_containers.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/session/environment.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/session/experimental_onnxruntime_cxx_api.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/session/experimental_onnxruntime_cxx_inline.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/session/onnxruntime_c_api.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/session/onnxruntime_cxx_api.h create mode 100644 onnxruntime/src/main/cpp/includes/onnxruntime/core/session/onnxruntime_cxx_inline.h create mode 100644 onnxruntime/src/main/cpp/logs.h create mode 100644 onnxruntime/src/main/cpp/onnxjni.cpp create mode 100644 onnxruntime/src/main/cpp/onnxruntime_inference.cpp create mode 100644 onnxruntime/src/main/cpp/onnxruntime_inference.h create mode 100644 onnxruntime/src/main/cpp/postprocess.h create mode 100644 onnxruntime/src/main/cpp/preprocess.h create mode 100644 onnxruntime/src/main/cpp/utils.h create mode 100644 onnxruntime/src/main/java/ru/gordinmitya/onnxruntime/ConvertedModel.kt create mode 100644 onnxruntime/src/main/java/ru/gordinmitya/onnxruntime/ONNXClassifier.kt create mode 100644 onnxruntime/src/main/java/ru/gordinmitya/onnxruntime/ONNXFramework.kt create mode 100644 onnxruntime/src/main/java/ru/gordinmitya/onnxruntime/ONNXInfereceType.kt create mode 100644 onnxruntime/src/main/java/ru/gordinmitya/onnxruntime/ONNXNative.java create mode 100755 onnxruntime/src/main/jniLibs/arm64-v8a/libonnxruntime.so create mode 100755 onnxruntime/src/main/jniLibs/armeabi-v7a/libonnxruntime.so delete mode 100644 opencv/src/main/assets/.gitattributes diff --git a/.gitattributes b/.gitattributes index e09eab1..946ebeb 100644 --- a/.gitattributes +++ b/.gitattributes @@ -13,3 +13,5 @@ mace/src/main/assets/mace/*.data filter=lfs diff=lfs merge=lfs -text mace/src/main/assets/mace/*.pb filter=lfs diff=lfs merge=lfs -text mace/src/main/jniLibs/*/*.so filter=lfs diff=lfs merge=lfs -text snpe/src/main/assets/snpe/*.dlc filter=lfs diff=lfs merge=lfs -text +onnxruntime/src/main/assets/*.onnx filter=lfs diff=lfs merge=lfs -text +onnxruntime/src/main/jniLibs/*/*.so filter=lfs diff=lfs merge=lfs -text diff --git a/app/build.gradle b/app/build.gradle index f9acc83..d516cfa 100644 --- a/app/build.gradle +++ b/app/build.gradle @@ -105,4 +105,5 @@ dependencies { implementation project(path: ':opencv') implementation project(path: ':mace') implementation project(path: ':snpe') + implementation project(path: ':onnxruntime') } diff --git a/app/src/main/java/ru/gordinmitya/dnnbenchmark/App.kt b/app/src/main/java/ru/gordinmitya/dnnbenchmark/App.kt index aab445e..151b057 100644 --- a/app/src/main/java/ru/gordinmitya/dnnbenchmark/App.kt +++ b/app/src/main/java/ru/gordinmitya/dnnbenchmark/App.kt @@ -8,6 +8,7 @@ import ru.gordinmitya.common.segmentation.DeepLabModel import ru.gordinmitya.mace.MACEFramework import ru.gordinmitya.mnn.MNNFramework import ru.gordinmitya.ncnn.NCNNFramework +import ru.gordinmitya.onnxruntime.ONNXFramework import ru.gordinmitya.opencv.OpenCVFramework import ru.gordinmitya.pytorch.PytorchFramework import ru.gordinmitya.snpe.SNPEFramework @@ -25,21 +26,22 @@ class App : Application() { super.onCreate() instance = this frameworks = listOf( - MNNFramework(), - TFLiteFramework(), + ONNXFramework(), +// MNNFramework(), +// TFLiteFramework(), // MACEFramework(), // SNPEFramework(), - OpenCVFramework(), +// OpenCVFramework(), // TFMobileFramework(), - PytorchFramework(), - NCNNFramework() +// PytorchFramework(), +// NCNNFramework() ) } @Suppress("SimplifyBooleanWithConstants") companion object { val DEBUG = true && BuildConfig.DEBUG - val USE_PROCESS = true || !DEBUG + val USE_PROCESS = false || !DEBUG lateinit var instance: App } diff --git a/onnxruntime/.gitignore b/onnxruntime/.gitignore new file mode 100644 index 0000000..42afabf --- /dev/null +++ b/onnxruntime/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/onnxruntime/build.gradle b/onnxruntime/build.gradle new file mode 100644 index 0000000..cb62f11 --- /dev/null +++ b/onnxruntime/build.gradle @@ -0,0 +1,47 @@ +apply plugin: 'com.android.library' +apply plugin: 'kotlin-android' +apply plugin: 'kotlin-android-extensions' + +android { + compileSdkVersion rootProject.compileSdkVersion + + defaultConfig { + minSdkVersion rootProject.minSdkVersion + targetSdkVersion rootProject.targetSdkVersion + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + consumerProguardFiles "consumer-rules.pro" + + ndk { + abiFilters 'arm64-v8a', 'armeabi-v7a' + } + + externalNativeBuild { + cmake { + arguments "-DANDROID_ARM_NEON=TRUE", "-DANDROID_PLATFORM=android-24", "-DANDROID_STL=c++_shared" + } + } + } + + externalNativeBuild { + cmake { + version "3.10.2" + path "src/main/cpp/CMakeLists.txt" + } + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } +} + +dependencies { + implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version" + + implementation project(path: ':common') +} \ No newline at end of file diff --git a/onnxruntime/consumer-rules.pro b/onnxruntime/consumer-rules.pro new file mode 100644 index 0000000..e69de29 diff --git a/onnxruntime/proguard-rules.pro b/onnxruntime/proguard-rules.pro new file mode 100644 index 0000000..481bb43 --- /dev/null +++ b/onnxruntime/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/onnxruntime/src/main/AndroidManifest.xml b/onnxruntime/src/main/AndroidManifest.xml new file mode 100644 index 0000000..19a34aa --- /dev/null +++ b/onnxruntime/src/main/AndroidManifest.xml @@ -0,0 +1,5 @@ + + + / + \ No newline at end of file diff --git a/opencv/src/main/assets/mobilenet_v2.onnx b/onnxruntime/src/main/assets/mobilenet_v2.onnx similarity index 100% rename from opencv/src/main/assets/mobilenet_v2.onnx rename to onnxruntime/src/main/assets/mobilenet_v2.onnx diff --git a/onnxruntime/src/main/cpp/CMakeLists.txt b/onnxruntime/src/main/cpp/CMakeLists.txt new file mode 100644 index 0000000..6bd9074 --- /dev/null +++ b/onnxruntime/src/main/cpp/CMakeLists.txt @@ -0,0 +1,29 @@ +cmake_minimum_required(VERSION 3.4.1) + +set(lib_DIR "${CMAKE_SOURCE_DIR}/../jniLibs") +include_directories(${CMAKE_SOURCE_DIR}/includes) +include_directories(${CMAKE_SOURCE_DIR}/includes/onnxruntime/core/session) + +add_library( + onnxcore + SHARED + onnxjni.cpp onnxruntime_inference.cpp +) + +add_library(libonnxruntime STATIC IMPORTED) +set_target_properties( + libonnxruntime + PROPERTIES IMPORTED_LOCATION + ${lib_DIR}/${ANDROID_ABI}/libonnxruntime.so +) + +find_library(log-lib log) +find_library(jnigraphics-lib jnigraphics) + +target_link_libraries( + onnxcore + libonnxruntime + + ${log-lib} + ${jnigraphics-lib} +) diff --git a/onnxruntime/src/main/cpp/imageresizer.h b/onnxruntime/src/main/cpp/imageresizer.h new file mode 100644 index 0000000..48785e5 --- /dev/null +++ b/onnxruntime/src/main/cpp/imageresizer.h @@ -0,0 +1,105 @@ +// +// Created by rohith on 7/24/18. +// + +#ifndef TFLITENATIVE_IMAGERESIZER_H +#define TFLITENATIVE_IMAGERESIZER_H + +#include "logs.h" + +#include + + +/** + *center crop and resize the image + * + */ +void cropResizeImage(unsigned char *inputpixel, int w_image, int h_image ,int ch_image, uint8_t* output, int w_network, int h_network) +{ + LOGV("in width %d, in height %d",w_image, h_image); + + if(h_image >= w_image)//portrait mode + { + LOGE("portrait"); + const int skipheight = (h_image - w_image) / 2; + LOGE("skipheight %d", skipheight); + + int croped_width = w_image; + int croped_height = w_image; + + unsigned char* start_address = inputpixel + (skipheight*croped_width)*ch_image; + uint8_t* rgb_in = output; + + for(int y =0;y +#include +#include + +namespace onnxruntime { +/** + CodeLocation captures information on where in the source code a message came from. +*/ +struct CodeLocation { + /** + @param file_path Usually the value of __FILE__ + @param line Usually the value of __LINE__ + @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__ + */ + CodeLocation(const char* file_path, const int line, const char* func) + : file_and_path{file_path}, line_num{line}, function{func} { + } + + /** + @param file_path Usually the value of __FILE__ + @param line Usually the value of __LINE__ + @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__ + @param stacktrace Stacktrace from source of message. + */ + CodeLocation(const char* file_path, const int line, const char* func, const std::vector& stacktrace) + : file_and_path{file_path}, line_num{line}, function{func}, stacktrace(stacktrace) { + } + + std::string FileNoPath() const { + // assuming we always have work to do, so not trying to avoid creating a new string if + // no path was removed. + return file_and_path.substr(file_and_path.find_last_of("/\\") + 1); + } + + enum Format { + kFilename, + kFilenameAndPath + }; + + std::string ToString(Format format = Format::kFilename) const { + std::ostringstream out; + out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function; + return out.str(); + } + + const std::string file_and_path; + const int line_num; + const std::string function; + const std::vector stacktrace; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/common.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/common.h new file mode 100644 index 0000000..f6ff7d3 --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/common.h @@ -0,0 +1,245 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Portions Copyright (c) Microsoft Corporation + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/common/code_location.h" +#include "core/common/exceptions.h" +#include "core/common/make_unique.h" +#include "core/common/status.h" + +#ifdef USE_MIMALLOC_ARENA_ALLOCATOR +#include +#endif + +namespace onnxruntime { + +using TimePoint = std::chrono::high_resolution_clock::time_point; + +// Using statements for common classes that we refer to in ONNXRuntime very often. +// TODO(Task:137) Remove 'using' statements from header files +using common::Status; + +#ifdef _WIN32 +#define ORT_UNUSED_PARAMETER(x) (x) +#else +#define ORT_UNUSED_PARAMETER(x) (void)(x) +#endif + +#ifndef ORT_HAVE_ATTRIBUTE +#ifdef __has_attribute +#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x) +#else +#define ORT_HAVE_ATTRIBUTE(x) 0 +#endif +#endif + +// ORT_ATTRIBUTE_UNUSED +// +// Prevents the compiler from complaining about or optimizing away variables +// that appear unused on Linux +#if ORT_HAVE_ATTRIBUTE(unused) || (defined(__GNUC__) && !defined(__clang__)) +#undef ORT_ATTRIBUTE_UNUSED +#define ORT_ATTRIBUTE_UNUSED __attribute__((__unused__)) +#else +#define ORT_ATTRIBUTE_UNUSED +#endif + +// macro to explicitly ignore the return value from a function call so Code Analysis doesn't complain +#define ORT_IGNORE_RETURN_VALUE(fn) \ + static_cast(fn) + +std::vector GetStackTrace(); +// these is a helper function that gets defined by platform/Telemetry +void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, + const char* function, uint32_t line); + +// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER +// so we only define it as one for MSVC +#if (_MSC_VER && !defined(__PRETTY_FUNCTION__)) +#define __PRETTY_FUNCTION__ __FUNCTION__ +#endif + +// Capture where a message is coming from. Use __FUNCTION__ rather than the much longer __PRETTY_FUNCTION__ +#define ORT_WHERE \ + ::onnxruntime::CodeLocation(__FILE__, __LINE__, __FUNCTION__) + +#define ORT_WHERE_WITH_STACK \ + ::onnxruntime::CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__, ::onnxruntime::GetStackTrace()) + +// Throw an exception with optional message. +// NOTE: The arguments get streamed into a string via ostringstream::operator<< +// DO NOT use a printf format string, as that will not work as you expect. +#define ORT_THROW(...) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__)) + +// Just in order to mark things as not implemented. Do not use in final code. +#define ORT_NOT_IMPLEMENTED(...) \ + throw ::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__)) + +// Check condition. +// NOTE: The arguments get streamed into a string via ostringstream::operator<< +// DO NOT use a printf format string, as that will not work as you expect. +#define ORT_ENFORCE(condition, ...) \ + if (!(condition)) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, #condition, \ + ::onnxruntime::MakeString(__VA_ARGS__)) + +#define ORT_MAKE_STATUS(category, code, ...) \ + ::onnxruntime::common::Status(::onnxruntime::common::category, \ + ::onnxruntime::common::code, \ + ::onnxruntime::MakeString(__VA_ARGS__)) + +// Check condition. if met, return status. +#define ORT_RETURN_IF(condition, ...) \ + if (condition) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, \ + "Satisfied, but should not be: " #condition "\n", \ + ORT_WHERE.ToString(), ::onnxruntime::MakeString(__VA_ARGS__)); \ + } + +// Check condition. if not met, return status. +#define ORT_RETURN_IF_NOT(condition, ...) \ + if (!(condition)) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Not satisfied: " #condition "\n", \ + ORT_WHERE.ToString(), ::onnxruntime::MakeString(__VA_ARGS__)); \ + } + +// Macros to disable the copy and/or move ctor and assignment methods +// These are usually placed in the private: declarations for a class. + +#define ORT_DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete + +#define ORT_DISALLOW_ASSIGNMENT(TypeName) TypeName& operator=(const TypeName&) = delete + +#define ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName) \ + ORT_DISALLOW_COPY(TypeName); \ + ORT_DISALLOW_ASSIGNMENT(TypeName) + +#define ORT_DISALLOW_MOVE(TypeName) \ + TypeName(TypeName&&) = delete; \ + TypeName& operator=(TypeName&&) = delete + +#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName) \ + ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName); \ + ORT_DISALLOW_MOVE(TypeName) + +#define ORT_RETURN_IF_ERROR_SESSIONID(expr, session_id) \ + do { \ + auto _status = (expr); \ + if ((!_status.IsOK())) { \ + ::onnxruntime::LogRuntimeError(session_id, _status, __FILE__, __FUNCTION__, __LINE__); \ + return _status; \ + } \ + } while (0) + +#define ORT_RETURN_IF_ERROR_SESSIONID_(expr) ORT_RETURN_IF_ERROR_SESSIONID(expr, session_id_) +#define ORT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR_SESSIONID(expr, 0) + +#define ORT_THROW_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if ((!_status.IsOK())) { \ + ::onnxruntime::LogRuntimeError(0, _status, __FILE__, __FUNCTION__, __LINE__); \ + ORT_THROW(_status); \ + } \ + } while (0) + +// use this macro when cannot early return +#define ORT_CHECK_AND_SET_RETVAL(expr) \ + do { \ + if (retval.IsOK()) { \ + retval = (expr); \ + } \ + } while (0) + +// C++ Core Guideline check suppression. +#if defined(_MSC_VER) && !defined(__NVCC__) +#define GSL_SUPPRESS(tag) [[gsl::suppress(tag)]] +#else +#define GSL_SUPPRESS(tag) +#endif + +inline void MakeStringInternal(std::ostringstream& /*ss*/) noexcept { +} + +template +inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept { + ss << t; +} + +template +inline void MakeStringInternal(std::ostringstream& ss, const T& t, const Args&... args) noexcept { + ::onnxruntime::MakeStringInternal(ss, t); + ::onnxruntime::MakeStringInternal(ss, args...); +} + +template +std::string MakeString(const Args&... args) { + std::ostringstream ss; + ::onnxruntime::MakeStringInternal(ss, args...); + return std::string(ss.str()); +} + +// Specializations for already-a-string types. +template <> +inline std::string MakeString(const std::string& str) { + return str; +} +inline std::string MakeString(const char* p_str) { + return p_str; +} + +inline long long TimeDiffMicroSeconds(TimePoint start_time) { + auto end_time = std::chrono::high_resolution_clock::now(); + return std::chrono::duration_cast(end_time - start_time).count(); +} + +inline long long TimeDiffMicroSeconds(TimePoint start_time, TimePoint end_time) { + return std::chrono::duration_cast(end_time - start_time).count(); +} + +struct null_type {}; +inline std::string ToMBString(const std::string& s) { return s; } +#ifdef _WIN32 +/** + * Convert a wide character string into a narrow one, with local ANSI code page(like CP936) + * DO NOT assume the result string is encoded in UTF-8 + */ +std::string ToMBString(const std::wstring& s); + +std::wstring ToWideString(const std::string& s); +inline std::wstring ToWideString(const std::wstring& s) { return s; } +#else +inline std::string ToWideString(const std::string& s) { return s; } +#endif + +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/const_pointer_container.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/const_pointer_container.h new file mode 100644 index 0000000..1d821ba --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/const_pointer_container.h @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +/** + Container has T* entries. e.g. std::vector, and this class provides const access to those + via iterators and direct access, as the standard behavior only makes the pointer constant, + and not what is pointed too. i.e. you get a const pointer to T not a pointer to const T without this wrapper. + See https://stackoverflow.com/questions/8017036/understanding-const-iterator-with-pointers +*/ +template +class ConstPointerContainer { + public: + using T = typename std::remove_pointer::type; + + class ConstIterator { + public: + using const_iterator = typename Container::const_iterator; + using iterator_category = std::input_iterator_tag; + using value_type = T*; + using difference_type = std::ptrdiff_t; + using pointer = T**; + using reference = T*&; + + /** Construct iterator for container that will return const T* entries.*/ + explicit ConstIterator(const_iterator position) noexcept : current_{position}, item_{nullptr} {} + ConstIterator(const ConstIterator& other) = default; + ConstIterator& operator=(const ConstIterator& other) = default; + + bool operator==(const ConstIterator& other) const noexcept { return current_ == other.current_; } + bool operator!=(const ConstIterator& other) const noexcept { return current_ != other.current_; } + + ConstIterator& operator++() { + ++current_; + return *this; + } + + ConstIterator operator++(int) { + ConstIterator tmp{*this}; + ++(*this); + return tmp; + } + + const T*& operator*() const { + item_ = *current_; + return item_; + } + + const T** operator->() const { return &(operator*()); }; + + private: + const_iterator current_; + mutable const T* item_; + }; + + /** + Construct wrapper class that will provide const access to the pointers in a container of non-const pointers. + @param data Container with non-const pointers. e.g. std::vector + */ + explicit ConstPointerContainer(const Container& data) noexcept : data_(data) {} + + size_t size() const noexcept { return data_.size(); } + bool empty() const noexcept { return data_.empty(); } + + ConstIterator cbegin() const noexcept { return ConstIterator(data_.cbegin()); } + ConstIterator cend() const noexcept { return ConstIterator(data_.cend()); } + + ConstIterator begin() const noexcept { return ConstIterator(data_.cbegin()); } + ConstIterator end() const noexcept { return ConstIterator(data_.cend()); } + + const T* operator[](size_t index) const { return data_[index]; } + + const T* at(size_t index) const { + ORT_ENFORCE(index < data_.size()); + return data_[index]; + } + + private: + const Container& data_; +}; +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/eigen_common_wrapper.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/eigen_common_wrapper.h new file mode 100644 index 0000000..fa15399 --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/eigen_common_wrapper.h @@ -0,0 +1,41 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +//----------------------------------------------------------------------------- +#pragma once +#include "onnxruntime_config.h" +// build/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h:162:71: +// error: ignoring attributes on template argument "Eigen::PacketType::type {aka __vector(4) float}" [-Werror=ignored-attributes] +#if defined(__GNUC__) +#pragma GCC diagnostic push +#if __GNUC__ >= 6 +#pragma GCC diagnostic ignored "-Wignored-attributes" +#endif +#pragma GCC diagnostic ignored "-Wunused-parameter" +#ifdef HAS_DEPRECATED_COPY +#pragma GCC diagnostic ignored "-Wdeprecated-copy" +#endif +#elif defined(_MSC_VER) +// build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76): +// warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence + +// unsupported\eigen\cxx11\src\Tensor\TensorUInt128.h(150,0): Warning C4245: 'initializing': conversion from '__int64' +// to 'uint64_t', signed/unsigned mismatch +#pragma warning(push) +#pragma warning(disable : 4554) +#pragma warning(disable : 4245) +#pragma warning(disable : 4127) +#pragma warning(disable : 4805) +#pragma warning(disable : 6313) +#pragma warning(disable : 6294) +#endif + +#include "unsupported/Eigen/CXX11/Tensor" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/exceptions.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/exceptions.h new file mode 100644 index 0000000..cbebc88 --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/exceptions.h @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/code_location.h" + +namespace onnxruntime { + +class NotImplementedException : public std::logic_error { + public: + explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; + explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; +}; + +class TypeMismatchException : public std::logic_error { + public: + TypeMismatchException() noexcept : logic_error("Type mismatch"){}; +}; + +class OnnxRuntimeException : public std::exception { + public: + OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept + : OnnxRuntimeException(location, nullptr, msg) { + } + + /** + Create a new exception that captures the location it was thrown from. + @param location Location in the source code the exception is being thrown from + @param failed_condition Optional string containing the condition that failed. + e.g. "tensor.Size() == input.Size()". May be nullptr. + @param msg Message containing additional information about the exception cause. + */ + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) + : location_{location} { + std::ostringstream ss; + + ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous + if (failed_condition != nullptr) { + ss << " " << failed_condition << " was false."; + } + + ss << " " << msg << "\n"; + if (!location.stacktrace.empty()) { + ss << "Stacktrace:\n"; + // skip the first entry in the stacktrace as we have that information from location.ToString() + std::copy(++location.stacktrace.begin(), location.stacktrace.end(), std::ostream_iterator(ss, "\n")); + } + + what_ = ss.str(); + } + + const char* what() const noexcept override { + return what_.c_str(); + } + + private: + const CodeLocation location_; + const std::vector stacktrace_; + std::string what_; +}; + +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/capture.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/capture.h new file mode 100644 index 0000000..4f71bb3 --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/capture.h @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/common/common.h" +#include "core/common/code_location.h" +#include "core/common/logging/severity.h" + +namespace onnxruntime { +namespace logging { + +class Logger; +enum class DataType; + +/** + Class to capture the details of a log message. +*/ +class Capture { + public: + /** + Initializes a new instance of the Capture class. + @param logger The logger. + @param severity The severity. + @param category The category. + @param dataType Type of the data. + @param location The file location the log message is coming from. + */ + Capture(const Logger& logger, logging::Severity severity, const char* category, + logging::DataType dataType, const CodeLocation& location) + : logger_{&logger}, severity_{severity}, category_{category}, data_type_{dataType}, location_{location} { + } + + /** + The stream that can capture the message via operator<<. + @returns Output stream. + */ + std::ostream& Stream() noexcept { + return stream_; + } + +#ifdef _MSC_VER +// add SAL annotation for printf format string. requires Code Analysis to run to validate usage. +#define msvc_printf_check _Printf_format_string_ +#define __attribute__(x) // Disable for MSVC. Supported by GCC and CLang. +#else +#define msvc_printf_check +#endif + + /** + Captures a printf style log message. + @param name="format">The printf format. + @param name="">Arguments to the printf format if needed. + @remarks + A maximum of 2K of output will be captured currently. + Non-static method, so 'this' is implicit first arg, and we use format(printf(2,3) + */ + void CapturePrintf(msvc_printf_check const char* format, ...) __attribute__((format(printf, 2, 3))); + + /** + Process a printf style log message. + @param format The printf format. + @param ... Arguments to the printf format if needed. + @remarks + A maximum of 2K of output will be captured currently. + Note: As va_list is 'char *', we have to disambiguate this from CapturePrintf + so that something like "One string: %s", "the string" does not consider "the string" + to be the va_list. + */ + void ProcessPrintf(msvc_printf_check const char* format, va_list args); + + logging::Severity Severity() const noexcept { + return severity_; + } + + char SeverityPrefix() const noexcept { + // Carefully setup so severity_ is a valid index + GSL_SUPPRESS(bounds .2) { + return logging::SEVERITY_PREFIX[static_cast(severity_)]; + } + } + + const char* Category() const noexcept { + return category_; + } + + logging::DataType DataType() const noexcept { + return data_type_; + } + + const CodeLocation& Location() const noexcept { + return location_; + } + + std::string Message() const noexcept { + return stream_.str(); + } + + ~Capture(); + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Capture); + + const Logger* logger_; + const logging::Severity severity_; + const char* category_; + const logging::DataType data_type_; + const CodeLocation location_; + + std::ostringstream stream_; +}; +} // namespace logging +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/isink.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/isink.h new file mode 100644 index 0000000..a67777d --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/isink.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/common/logging/logging.h" + +namespace onnxruntime { +namespace logging { +class ISink { + public: + ISink() = default; + + /** + Sends the message to the sink. + @param timestamp The timestamp. + @param logger_id The logger identifier. + @param message The captured message. + */ + void Send(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) { + SendImpl(timestamp, logger_id, message); + } + + /** + Sends a Profiling Event Record to the sink. + @param Profiling Event Record + */ + virtual void SendProfileEvent(profiling::EventRecord&) const {}; + + virtual ~ISink() = default; + + private: + // Make Code Analysis happy by disabling all for now. Enable as needed. + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ISink); + + virtual void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) = 0; +}; +} // namespace logging +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/logging.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/logging.h new file mode 100644 index 0000000..0c8919f --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/logging.h @@ -0,0 +1,347 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/logging/capture.h" +#include "core/common/logging/severity.h" + +#include "core/common/logging/macros.h" + +/* + + Logging overview and expected usage: + + At program startup: + * Create one or more ISink instances. If multiple, combine using composite_sink. + * Create a LoggingManager instance with the sink/s with is_default_instance set to true + * Only one instance should be created in this way, and it should remain valid for + until the program no longer needs to produce log output. + + You can either use the static default Logger which LoggingManager will create when constructed + via LoggingManager::DefaultLogger(), or separate Logger instances each with different log ids + via LoggingManager::CreateLogger. + + The log id is passed to the ISink instance with the sink determining how the log id is used + in the output. + + LoggingManager + * creates the Logger instances used by the application + * provides a static default logger instance + * owns the log sink instance + * applies checks on severity and output of user data + + The log macros create a Capture instance to capture the information to log. + If the severity and/or user filtering settings would prevent logging, no evaluation + of the log arguments will occur, so no performance cost beyond the severity and user + filtering check. + + A sink can do further filter as needed. + +*/ + +namespace onnxruntime { +namespace profiling { + +enum EventCategory { + SESSION_EVENT = 0, + NODE_EVENT, + EVENT_CATEGORY_MAX +}; + +/* +Event descriptions for the above session events. +*/ +static constexpr const char* event_categor_names_[EVENT_CATEGORY_MAX] = { + "Session", + "Node"}; + +/* +Timing record for all events. +*/ +struct EventRecord { + EventRecord(EventCategory category, + int process_id, + int thread_id, + std::string event_name, + long long time_stamp, + long long duration, + std::unordered_map&& event_args) : cat(category), + pid(process_id), + tid(thread_id), + name(std::move(event_name)), + ts(time_stamp), + dur(duration), + args(event_args) {} + EventCategory cat; + int pid; + int tid; + std::string name; + long long ts; + long long dur; + std::unordered_map args; +}; +} // namespace profiling + +namespace logging { + +using Timestamp = std::chrono::time_point; + +#ifndef NDEBUG +ORT_ATTRIBUTE_UNUSED static bool vlog_enabled = true; // Set directly based on your needs. +#else +constexpr bool vlog_enabled = false; // no VLOG output +#endif + +enum class DataType { + SYSTEM = 0, ///< System data. + USER = 1 ///< Contains potentially sensitive user data. +}; + +// Internal log categories. +// Logging interface takes const char* so arbitrary values can also be used. +struct Category { + static const char* onnxruntime; ///< General output + static const char* System; ///< Log output regarding interactions with the host system + // TODO: What other high level categories are meaningful? Model? Optimizer? Execution? +}; + +class ISink; +class Logger; +class Capture; + +/// +/// The logging manager. +/// Owns the log sink and potentially provides a default Logger instance. +/// Provides filtering based on a minimum LogSeverity level, and of messages with DataType::User if enabled. +/// +class LoggingManager final { + public: + enum InstanceType { + Default, ///< Default instance of LoggingManager that should exist for the lifetime of the program + Temporal ///< Temporal instance. CreateLogger(...) should be used, however DefaultLogger() will NOT be provided via this instance. + }; + + /** + Initializes a new instance of the LoggingManager class. + @param sink The sink to write to. Use CompositeSink if you need to write to multiple places. + @param default_min_severity The default minimum severity. Messages with lower severity will be ignored unless + overridden in CreateLogger. + @param default_filter_user_data If set to true ignore messages with DataType::USER unless overridden in CreateLogger. + @param instance_type If InstanceType::Default, this is the default instance of the LoggingManager + and is expected to exist for the lifetime of the program. + It creates and owns the default logger that calls to the static DefaultLogger method return. + @param default_logger_id Logger Id to use for the default logger. nullptr/ignored if instance_type == Temporal. + @param default_max_vlog_level Default maximum level for VLOG messages to be created unless overridden in CreateLogger. + Requires a severity of kVERBOSE for VLOG messages to be logged. + */ + LoggingManager(std::unique_ptr sink, Severity default_min_severity, bool default_filter_user_data, + InstanceType instance_type, + const std::string* default_logger_id = nullptr, + int default_max_vlog_level = -1); + + /** + Creates a new logger instance which will use the provided logger_id and default severity and vlog levels. + @param logger_id The log identifier. + @returns A new Logger instance that the caller owns. + */ + std::unique_ptr CreateLogger(const std::string& logger_id); + + /** + Creates a new logger instance which will use the provided logger_id, severity and vlog levels. + @param logger_id The log identifier. + @param min_severity The minimum severity. Requests to create messages with lower severity will be ignored. + @param filter_user_data If set to true ignore messages with DataType::USER. + @param max_vlog_level Maximum level for VLOG messages to be created. + @returns A new Logger instance that the caller owns. + */ + std::unique_ptr CreateLogger(const std::string& logger_id, + Severity min_severity, bool filter_user_data, int max_vlog_level = -1); + + /** + Gets the default logger instance if set. Throws if no default logger is currently registered. + @remarks + Creating a LoggingManager instance with is_default_instance == true registers a default logger. + Note that the default logger is only valid until the LoggerManager that registered it is destroyed. + @returns The default logger if available. + */ + static const Logger& DefaultLogger(); + + /** + Change the minimum severity level for log messages to be output by the default logger. + @param severity The severity. + */ + static void SetDefaultLoggerSeverity(Severity severity); + + /** + Logs a FATAL level message and creates an exception that can be thrown with error information. + @param category The log category. + @param location The location the log message was generated. + @param format_str The printf format string. + @param ... The printf arguments. + @returns A new Logger instance that the caller owns. + */ + static std::exception LogFatalAndCreateException(const char* category, + const CodeLocation& location, + const char* format_str, ...); + + /** + Logs the message using the provided logger id. + @param logger_id The log identifier. + @param message The log message. + */ + void Log(const std::string& logger_id, const Capture& message) const; + + /** + Sends a Profiling Event Record to the sink. + @param Profiling Event Record + */ + void SendProfileEvent(profiling::EventRecord& eventRecord) const; + ~LoggingManager(); + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LoggingManager); + + Timestamp GetTimestamp() const noexcept; + void CreateDefaultLogger(const std::string& logger_id); + + std::unique_ptr sink_; + const Severity default_min_severity_; + const bool default_filter_user_data_; + const int default_max_vlog_level_; + bool owns_default_logger_; + + static Logger* s_default_logger_; + + struct Epochs { + const std::chrono::time_point high_res; + const std::chrono::time_point system; + const std::chrono::minutes localtime_offset_from_utc; + }; + + static const Epochs& GetEpochs() noexcept; +}; + +/** + Logger provides a per-instance log id. Everything else is passed back up to the LoggingManager +*/ +class Logger { + public: + /** + Initializes a new instance of the Logger class. + @param loggingManager The logging manager. + @param id The identifier for messages coming from this Logger. + @param severity Minimum severity for messages to be created and logged. + @param filter_user_data Should USER data be filtered from output. + @param vlog_level Minimum level for VLOG messages to be created. Note that a severity of kVERBOSE must be provided + for VLOG messages to be logged. + */ + Logger(const LoggingManager& loggingManager, std::string id, + Severity severity, bool filter_user_data, int vlog_level) + : logging_manager_{&loggingManager}, + id_{id}, + min_severity_{severity}, + filter_user_data_{filter_user_data}, + max_vlog_level_{severity > Severity::kVERBOSE ? -1 : vlog_level} { // disable unless logging VLOG messages + } + + /** + Get the minimum severity level for log messages to be output. + @returns The severity. + */ + Severity GetSeverity() const noexcept { return min_severity_; } + + /** + Change the minimum severity level for log messages to be output. + @param severity The severity. + */ + void SetSeverity(Severity severity) noexcept { min_severity_ = severity; } + + /** + Check if output is enabled for the provided LogSeverity and DataType values. + @param severity The severity. + @param data_type Type of the data. + @returns True if a message with these values will be logged. + */ + bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { + return (severity >= min_severity_ && (data_type != DataType::USER || !filter_user_data_)); + } + + /** + Return the maximum VLOG level allowed. + */ + int VLOGMaxLevel() const noexcept { + return max_vlog_level_; + } + + /** + Logs the captured message. + @param message The log message. + */ + void Log(const Capture& message) const { + logging_manager_->Log(id_, message); + } + + /** + Sends a Profiling Event Record to the sink. + @param Profiling Event Record + */ + void SendProfileEvent(profiling::EventRecord& eventRecord) const { + logging_manager_->SendProfileEvent(eventRecord); + } + + private: + const LoggingManager* logging_manager_; + const std::string id_; + Severity min_severity_; + const bool filter_user_data_; + const int max_vlog_level_; +}; + +inline const Logger& LoggingManager::DefaultLogger() { + if (s_default_logger_ == nullptr) { + // fail early for attempted misuse. don't use logging macros as we have no logger. + throw std::logic_error("Attempt to use DefaultLogger but none has been registered."); + } + + return *s_default_logger_; +} + +inline void LoggingManager::SetDefaultLoggerSeverity(Severity severity) { + if (s_default_logger_ == nullptr) { + // fail early for attempted misuse. don't use logging macros as we have no logger. + throw std::logic_error("Attempt to use DefaultLogger but none has been registered."); + } + + s_default_logger_->SetSeverity(severity); +} + +inline Timestamp LoggingManager::GetTimestamp() const noexcept { + static const Epochs& epochs = GetEpochs(); + + const auto high_res_now = std::chrono::high_resolution_clock::now(); + return std::chrono::time_point_cast( + epochs.system + (high_res_now - epochs.high_res) + epochs.localtime_offset_from_utc); +} + +/** + Return the current thread id. +*/ +unsigned int GetThreadId(); + +/** + Return the current process id. +*/ +unsigned int GetProcessId(); + +} // namespace logging +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/macros.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/macros.h new file mode 100644 index 0000000..570bc14 --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/macros.h @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +// NOTE: Don't include this file directly. Include logging.h + +#define CREATE_MESSAGE(logger, severity, category, datatype) \ + ::onnxruntime::logging::Capture(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE) + +/* + Both printf and stream style logging are supported. + Not that printf currently has a 2K limit to the message size. + + LOGS_* macros are for stream style + LOGF_* macros are for printf style + + The Message class captures the log input, and pushes it through the logger in its destructor. + + Use the *FATAL* macros if you want a Severity::kFatal message to also throw. + + There are a few variants to minimize the length of the macro name required in the calling code. + They are optimized so the shortest names are for the (expected) most common usage. This can be + tweaked if needed. + + Explicit logger vs LoggingManager::DefaulLogger() + Default is for a logger instance to be explicitly passed in. + The logger instance provides an identifier so that log messages from different runs can be separated. + + Variants with DEFAULT in the macro name use the default logger provided by logging manager. This is + static so accessible from any code, provided a LoggingManager instance created with InstanceType::Default + exists somewhere. See logging.h for further explanation of the expected setup. + + DataType + Default uses DataType::SYSTEM. + + Variants with USER in the macro name use DataType::USER. This is data that could be PII, and may need to + be filtered from output. LoggingManager applies this filtering. + + Category + Default category is ::onnxruntime::Logging::Category::onnxruntime. + + If you wish to provide a different category, use variants with CATEGORY in the macro name + +*/ + +// Logging with explicit category + +// iostream style logging. Capture log info in Message, and push to the logger in ~Message. +#define LOGS_CATEGORY(logger, severity, category) \ + if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::SYSTEM)) \ + CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM).Stream() + +#define LOGS_USER_CATEGORY(logger, severity, category) \ + if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::USER)) \ + CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::USER).Stream() + + // printf style logging. Capture log info in Message, and push to the logger in ~Message. +#define LOGF_CATEGORY(logger, severity, category, format_str, ...) \ + if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::SYSTEM)) \ + CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM).CapturePrintf(format_str, ##__VA_ARGS__) + +#define LOGF_USER_CATEGORY(logger, severity, category, format_str, ...) \ + if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::USER)) \ + CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::USER).CapturePrintf(format_str, ##__VA_ARGS__) + + // Logging with category of "onnxruntime" + +#define LOGS(logger, severity) \ + LOGS_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGS_USER(logger, severity) \ + LOGS_USER_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime) + + // printf style logging. Capture log info in Message, and push to the logger in ~Message. +#define LOGF(logger, severity, format_str, ...) \ + LOGF_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + +#define LOGF_USER(logger, severity, format_str, ...) \ + LOGF_USER_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + + /* + + Macros that use the default logger. + A LoggingManager instance must be currently valid for the default logger to be available. + + */ + + // Logging with explicit category + +#define LOGS_DEFAULT_CATEGORY(severity, category) \ + LOGS_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category) + +#define LOGS_USER_DEFAULT_CATEGORY(severity, category) \ + LOGS_USER_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category) + +#define LOGF_DEFAULT_CATEGORY(severity, category, format_str, ...) \ + LOGF_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category, format_str, ##__VA_ARGS__) + +#define LOGF_USER_DEFAULT_CATEGORY(severity, category, format_str, ...) \ + LOGF_USER_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category, format_str, ##__VA_ARGS__) + +// Logging with category of "onnxruntime" + +#define LOGS_DEFAULT(severity) \ + LOGS_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGS_USER_DEFAULT(severity) \ + LOGS_USER_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGF_DEFAULT(severity, format_str, ...) \ + LOGF_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + +#define LOGF_USER_DEFAULT(severity, format_str, ...) \ + LOGF_USER_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + + /* + + Conditional logging + + */ + + // Logging with explicit category + +#define LOGS_CATEGORY_IF(boolean_expression, logger, severity, category) \ + if ((boolean_expression) == true) LOGS_CATEGORY(logger, severity, category) + +#define LOGS_DEFAULT_CATEGORY_IF(boolean_expression, severity, category) \ + if ((boolean_expression) == true) LOGS_DEFAULT_CATEGORY(severity, category) + +#define LOGS_USER_CATEGORY_IF(boolean_expression, logger, severity, category) \ + if ((boolean_expression) == true) LOGS_USER_CATEGORY(logger, severity, category) + +#define LOGS_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, category) \ + if ((boolean_expression) == true) LOGS_USER_DEFAULT_CATEGORY(severity, category) + +#define LOGF_CATEGORY_IF(boolean_expression, logger, severity, category, format_str, ...) \ + if ((boolean_expression) == true) LOGF_CATEGORY(logger, severity, category, format_str, ##__VA_ARGS__) + +#define LOGF_DEFAULT_CATEGORY_IF(boolean_expression, severity, category, format_str, ...) \ + if ((boolean_expression) == true) LOGF_DEFAULT_CATEGORY(severity, category, format_str, ##__VA_ARGS__) + +#define LOGF_USER_CATEGORY_IF(boolean_expression, logger, severity, category, format_str, ...) \ + if ((boolean_expression) == true) LOGF_USER_CATEGORY(logger, severity, category, format_str, ##__VA_ARGS__) + +#define LOGF_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, category, format_str, ...) \ + if ((boolean_expression) == true) LOGF_USER_DEFAULT_CATEGORY(severity, category, format_str, ##__VA_ARGS__) + + // Logging with category of "onnxruntime" + +#define LOGS_IF(boolean_expression, logger, severity) \ + LOGS_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGS_DEFAULT_IF(boolean_expression, severity) \ + LOGS_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGS_USER_IF(boolean_expression, logger, severity) \ + LOGS_USER_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGS_USER_DEFAULT_IF(boolean_expression, severity) \ + LOGS_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime) + +#define LOGF_IF(boolean_expression, logger, severity, format_str, ...) \ + LOGF_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + +#define LOGF_DEFAULT_IF(boolean_expression, severity, format_str, ...) \ + LOGF_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) + +#define LOGF_USER_IF(boolean_expression, logger, severity, format_str, ...) \ + LOGF_USER_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime, \ + format_str, ##__VA_ARGS__) + +#define LOGF_USER_DEFAULT_IF(boolean_expression, severity, format_str, ...) \ + LOGF_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime, \ + format_str, ##__VA_ARGS__) + +/* + + Debug verbose logging of caller provided level. + Disabled in Release builds. + Use the _USER variants for VLOG statements involving user data that may need to be filtered. +*/ +#define VLOGS(logger, level) \ + if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \ + LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level) + +#define VLOGS_USER(logger, level) \ + if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \ + LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level) + +#define VLOGF(logger, level, format_str, ...) \ + if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \ + LOGF_CATEGORY(logger, VERBOSE, "VLOG" #level, format_str, ##__VA_ARGS__) + +#define VLOGF_USER(logger, level, format_str, ...) \ + if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \ + LOGF_USER_CATEGORY(logger, VERBOSE, "VLOG" #level, format_str, ##__VA_ARGS__) + + // Default logger variants +#define VLOGS_DEFAULT(level) \ + VLOGS(::onnxruntime::logging::LoggingManager::DefaultLogger(), level) + +#define VLOGS_USER_DEFAULT(level) \ + VLOGS_USER(::onnxruntime::logging::LoggingManager::DefaultLogger(), level) + +#define VLOGF_DEFAULT(level, format_str, ...) \ + VLOGF(::onnxruntime::logging::LoggingManager::DefaultLogger(), level, format_str, ##__VA_ARGS__) + +#define VLOGF_USER_DEFAULT(level, format_str, ...) \ + VLOGF_USER(::onnxruntime::logging::LoggingManager::DefaultLogger(), level, format_str, ##__VA_ARGS__) diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/severity.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/severity.h new file mode 100644 index 0000000..e43f192 --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/logging/severity.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace logging { +// mild violation of naming convention. the 'k' lets us use token concatenation in the macro +// ::onnxruntime::Logging::Severity::k##severity. It's not legal to have ::onnxruntime::Logging::Severity::##severity +// the uppercase makes the LOG macro usage look as expected for passing an enum value as it will be LOGS(logger, ERROR) +enum class Severity { + kVERBOSE = 0, + kINFO = 1, + kWARNING = 2, + kERROR = 3, + kFATAL = 4 +}; + +constexpr const char* SEVERITY_PREFIX = "VIWEF"; + +} // namespace logging +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/make_unique.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/make_unique.h new file mode 100644 index 0000000..b401f0d --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/make_unique.h @@ -0,0 +1,148 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: make_unique.h +// ----------------------------------------------------------------------------- +// +// This header file contains utility functions for managing the creation and +// conversion of smart pointers. This file is an extension to the C++ +// standard library header file. +/* Modifications Copyright (c) Microsoft. */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { + +template +using remove_extent_t = typename std::remove_extent::type; + +namespace memory_internal { + +// Traits to select proper overload and return type for `absl::make_unique<>`. +template +struct MakeUniqueResult { + using scalar = std::unique_ptr; +}; +template +struct MakeUniqueResult { + using array = std::unique_ptr; +}; +template +struct MakeUniqueResult { + using invalid = void; +}; + +} // namespace memory_internal + +// gcc 4.8 has __cplusplus at 201301 but doesn't define make_unique. Other +// supported compilers either just define __cplusplus as 201103 but have +// make_unique (msvc), or have make_unique whenever __cplusplus > 201103 (clang) +#if (__cplusplus > 201103L || defined(_MSC_VER)) && \ + !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8) +using std::make_unique; +#else +// ----------------------------------------------------------------------------- +// Function Template: make_unique() +// ----------------------------------------------------------------------------- +// +// Creates a `std::unique_ptr<>`, while avoiding issues creating temporaries +// during the construction process. `absl::make_unique<>` also avoids redundant +// type declarations, by avoiding the need to explicitly use the `new` operator. +// +// This implementation of `absl::make_unique<>` is designed for C++11 code and +// will be replaced in C++14 by the equivalent `std::make_unique<>` abstraction. +// `absl::make_unique<>` is designed to be 100% compatible with +// `std::make_unique<>` so that the eventual migration will involve a simple +// rename operation. +// +// For more background on why `std::unique_ptr(new T(a,b))` is problematic, +// see Herb Sutter's explanation on +// (Exception-Safe Function Calls)[https://herbsutter.com/gotw/_102/]. +// (In general, reviewers should treat `new T(a,b)` with scrutiny.) +// +// Example usage: +// +// auto p = make_unique(args...); // 'p' is a std::unique_ptr +// auto pa = make_unique(5); // 'pa' is a std::unique_ptr +// +// Three overloads of `absl::make_unique` are required: +// +// - For non-array T: +// +// Allocates a T with `new T(std::forward args...)`, +// forwarding all `args` to T's constructor. +// Returns a `std::unique_ptr` owning that object. +// +// - For an array of unknown bounds T[]: +// +// `absl::make_unique<>` will allocate an array T of type U[] with +// `new U[n]()` and return a `std::unique_ptr` owning that array. +// +// Note that 'U[n]()' is different from 'U[n]', and elements will be +// value-initialized. Note as well that `std::unique_ptr` will perform its +// own destruction of the array elements upon leaving scope, even though +// the array [] does not have a default destructor. +// +// NOTE: an array of unknown bounds T[] may still be (and often will be) +// initialized to have a size, and will still use this overload. E.g: +// +// auto my_array = absl::make_unique(10); +// +// - For an array of known bounds T[N]: +// +// `absl::make_unique<>` is deleted (like with `std::make_unique<>`) as +// this overload is not useful. +// +// NOTE: an array of known bounds T[N] is not considered a useful +// construction, and may cause undefined behavior in templates. E.g: +// +// auto my_array = absl::make_unique(); +// +// In those cases, of course, you can still use the overload above and +// simply initialize it to its desired size: +// +// auto my_array = absl::make_unique(10); + +// `absl::make_unique` overload for non-array types. +template +typename memory_internal::MakeUniqueResult::scalar make_unique( + Args&&... args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +// `absl::make_unique` overload for an array T[] of unknown bounds. +// The array allocation needs to use the `new T[size]` form and cannot take +// element constructor arguments. The `std::unique_ptr` will manage destructing +// these array elements. +template +typename memory_internal::MakeUniqueResult::array make_unique(size_t n) { + return std::unique_ptr(new typename onnxruntime::remove_extent_t[n]()); +} + +// `absl::make_unique` overload for an array T[N] of known bounds. +// This construction will be rejected. +template +typename memory_internal::MakeUniqueResult::invalid make_unique( + Args&&... /* args */) = delete; +#endif + +} \ No newline at end of file diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/optional.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/optional.h new file mode 100644 index 0000000..33cb41e --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/optional.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" + +namespace onnxruntime { +// This is a version of std::optional with limited functionality and plenty of +// room to improve. We should use std::optional when we move to C++17. +template +class optional { + public: + optional() : has_value_{false}, value_{} {} + + optional(const optional&) = default; + optional& operator=(const optional&) = default; + optional(optional&&) = default; + optional& operator=(optional&&) = default; + + optional(T value) : has_value_{true}, value_{value} {} + optional& operator=(T value) { + has_value_ = true; + value_ = value; + return *this; + } + + bool has_value() const { return has_value_; } + const T& value() const { + ORT_ENFORCE(has_value_); + return value_; + } + T& value() { + ORT_ENFORCE(has_value_); + return value_; + } + + private: + bool has_value_; + T value_; +}; +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/status.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/status.h new file mode 100644 index 0000000..c107655 --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/common/status.h @@ -0,0 +1,191 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Modifications Copyright (c) Microsoft. + +#pragma once + +#include +#include +#include +#ifdef _WIN32 +#include +#endif + +namespace onnxruntime { +namespace common { + +enum StatusCategory { + NONE = 0, + SYSTEM = 1, + ONNXRUNTIME = 2, +}; + +/** + Error code for ONNXRuntime. +*/ +enum StatusCode { + OK = 0, + FAIL = 1, + INVALID_ARGUMENT = 2, + NO_SUCHFILE = 3, + NO_MODEL = 4, + ENGINE_ERROR = 5, + RUNTIME_EXCEPTION = 6, + INVALID_PROTOBUF = 7, + MODEL_LOADED = 8, + NOT_IMPLEMENTED = 9, + INVALID_GRAPH = 10, + EP_FAIL = 11 +}; + +inline const char* StatusCodeToString(StatusCode status) noexcept { + switch (status) { + case StatusCode::OK: + return "SUCCESS"; + case StatusCode::FAIL: + return "FAIL"; + case StatusCode::INVALID_ARGUMENT: + return "INVALID_ARGUMENT"; + case StatusCode::NO_SUCHFILE: + return "NO_SUCHFILE"; + case StatusCode::NO_MODEL: + return "NO_MODEL"; + case StatusCode::ENGINE_ERROR: + return "ENGINE_ERROR"; + case StatusCode::RUNTIME_EXCEPTION: + return "RUNTIME_EXCEPTION"; + case StatusCode::INVALID_PROTOBUF: + return "INVALID_PROTOBUF"; + case StatusCode::MODEL_LOADED: + return "MODEL_LOADED"; + case StatusCode::NOT_IMPLEMENTED: + return "NOT_IMPLEMENTED"; + case StatusCode::INVALID_GRAPH: + return "INVALID_GRAPH"; + case StatusCode::EP_FAIL: + return "EP_FAIL"; + default: + return "GENERAL ERROR"; + } +} + +#ifdef _WIN32 +inline HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { + switch (status) + { + case StatusCode::OK: + return S_OK; + case StatusCode::FAIL: + return E_FAIL; + case StatusCode::INVALID_ARGUMENT: + return E_INVALIDARG; + case StatusCode::NO_SUCHFILE: + return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case StatusCode::NO_MODEL: + return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + case StatusCode::ENGINE_ERROR: + return E_FAIL; + case StatusCode::RUNTIME_EXCEPTION: + return E_FAIL; + case StatusCode::INVALID_PROTOBUF: + return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case StatusCode::MODEL_LOADED: + return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case StatusCode::NOT_IMPLEMENTED: + return E_NOTIMPL; + case StatusCode::INVALID_GRAPH: + return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); + case StatusCode::EP_FAIL: + return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + default: + return E_FAIL; + } +} +#endif + +class Status { + public: + Status() noexcept = default; + + Status(StatusCategory category, int code, const std::string& msg); + + Status(StatusCategory category, int code, const char* msg); + + Status(StatusCategory category, int code); + + Status(const Status& other) + : state_((other.state_ == nullptr) ? nullptr : new State(*other.state_)) {} + + Status& operator=(const Status& other) { + if (state_ != other.state_) { + if (other.state_ == nullptr) { + state_.reset(); + } else { + state_.reset(new State(*other.state_)); + } + } + return *this; + } + + Status(Status&&) = default; + Status& operator=(Status&&) = default; + ~Status() = default; + + bool IsOK() const { + return (state_ == nullptr); + } + + int Code() const noexcept; + + StatusCategory Category() const noexcept; + + const std::string& ErrorMessage() const noexcept; + + std::string ToString() const; + + bool operator==(const Status& other) const { + return (this->state_ == other.state_) || (ToString() == other.ToString()); + } + + bool operator!=(const Status& other) const { + return !(*this == other); + } + + static Status OK() { + return Status(); + } + + private: + static const std::string& EmptyString() noexcept; + + struct State { + State(StatusCategory cat0, int code0, const std::string& msg0) + : category(cat0), code(code0), msg(msg0) {} + + State(StatusCategory cat0, int code0, const char* msg0) + : category(cat0), code(code0), msg(msg0) {} + + const StatusCategory category; + const int code; + const std::string msg; + }; + + // As long as Code() is OK, state_ == nullptr. + std::unique_ptr state_; +}; + +inline std::ostream& operator<<(std::ostream& out, const Status& status) { + return out << status.ToString(); +} + +} // namespace common +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/alloc_kind.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/alloc_kind.h new file mode 100644 index 0000000..a749e6b --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/alloc_kind.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +namespace onnxruntime { +// The ml-Values fall into the following categories with respect to their +// memory management: +// - inference inputs: owned (allocated and freed) by caller, and is by +// default read-only by the runtime. +// - inference outputs: allocated by runtime, ownership transferred to +// caller. TODO: Make sure this semantics is clear in InferenceSession API. +// - weights (constant tensors): can be allocated once (statically), and +// reused by all inference calls within an InferenceSession. +// - tensor values: The lifetimes of these tensor-values are statically +// determined, which is used for memory reuse/sharing optimizations. The +// runtime allocates/frees these values at the right time (as determined +// by the static allocation plan). Note that this is simplified since we +// do not try to optimize for "slice" like ops, where we may be able to +// conditionally reuse memory/data in some cases but not others. +// Generalizing this is future work. + +enum class AllocKind { + kAllocate = 0, + kReuse = 1, + kPreExisting = 2, + kAllocateStatically = 3, + kAllocateOutput = 4, + kShare = 5 +}; + +std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind); +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/allocator.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/allocator.h new file mode 100644 index 0000000..111d743 --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/allocator.h @@ -0,0 +1,304 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/exceptions.h" +#include "core/common/status.h" +#include "core/framework/fence.h" +#include "core/session/onnxruntime_c_api.h" + +// Struct to represent a physical device. +struct OrtDevice { + using DeviceType = int8_t; + using MemoryType = int8_t; + using DeviceId = int16_t; + + // Pre-defined device types. + static const DeviceType CPU = 0; + static const DeviceType GPU = 1; //CUDA or HIP + static const DeviceType FPGA = 2; + + struct MemType { + // Pre-defined memory types. + static const MemoryType DEFAULT = 0; + static const MemoryType CUDA_PINNED = 1; + static const MemoryType HIP_PINNED = 2; + }; + + constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_) + : device_type(device_type_), + memory_type(memory_type_), + device_id(device_id_) {} + + constexpr OrtDevice() : OrtDevice(CPU, MemType::DEFAULT, 0) {} + + DeviceType Type() const { + return device_type; + } + + MemoryType MemType() const { + return memory_type; + } + + DeviceId Id() const { + return device_id; + } + + std::string ToString() const { + std::ostringstream ostr; + ostr << "Device:[" + << "DeviceType:" << static_cast(device_type) + << " MemoryType:" << static_cast(memory_type) + << " DeviceId:" << device_id + << "]"; + return ostr.str(); + } + + private: + // Device type. + DeviceType device_type; + + // Memory type. + MemoryType memory_type; + + // Device index. + DeviceId device_id; +}; + +inline bool operator==(const OrtDevice& left, const OrtDevice& other) { + return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type(); +} + +inline bool operator!=(const OrtDevice& left, const OrtDevice& other) { + return !(left == other); +} + +struct OrtMemoryInfo { + OrtMemoryInfo() = default; // to allow default construction of Tensor + + // use string for name, so we could have customized allocator in execution provider. + const char* name = nullptr; + int id = -1; + OrtMemType mem_type = OrtMemTypeDefault; + OrtAllocatorType alloc_type = Invalid; + OrtDevice device; + + constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0, + OrtMemType mem_type_ = OrtMemTypeDefault) +#if ((defined(__GNUC__) && __GNUC__ > 4) || defined(__clang__)) + // this causes a spurious error in CentOS gcc 4.8 build so disable if GCC version < 5 + __attribute__((nonnull)) +#endif + : name(name_), + id(id_), + mem_type(mem_type_), + alloc_type(type_), + device(device_) { + } + + // To make OrtMemoryInfo become a valid key in std map + bool operator<(const OrtMemoryInfo& other) const { + if (alloc_type != other.alloc_type) + return alloc_type < other.alloc_type; + if (mem_type != other.mem_type) + return mem_type < other.mem_type; + if (id != other.id) + return id < other.id; + + return strcmp(name, other.name) < 0; + } + + std::string ToString() const { + std::ostringstream ostr; + ostr << "OrtMemoryInfo:[" + << "name:" << name + << " id:" << id + << " OrtMemType:" << mem_type + << " OrtAllocatorType:" << alloc_type + << " " << device.ToString() + << "]"; + return ostr.str(); + } +}; + +inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) { + return left.mem_type == other.mem_type && + left.alloc_type == other.alloc_type && + left.id == other.id && + strcmp(left.name, other.name) == 0; +} + +inline bool operator!=(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) { return !(lhs == rhs); } + +std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info); + +namespace onnxruntime { +constexpr const char* CPU = "Cpu"; +constexpr const char* CUDA = "Cuda"; +constexpr const char* CUDA_PINNED = "CudaPinned"; +constexpr const char* MIGRAPHX = "MIGraphX"; +constexpr const char* MIGRAPHX_PINNED = "MIGraphXPinned"; +constexpr const char* TRT = "Tensorrt"; +constexpr const char* TRT_PINNED = "TensorrtPinned"; + +// forward declaration +class SessionState; + +template +using IAllocatorUniquePtr = std::unique_ptr>; + +class IAllocator { + public: + IAllocator(const OrtMemoryInfo& info) : memory_info_(info) {} + virtual ~IAllocator() = default; + /** + @remarks Use SafeInt when calculating the size of memory to allocate using Alloc. + */ + virtual void* Alloc(size_t size) = 0; + virtual void Free(void* p) = 0; + const OrtMemoryInfo& Info() const { return memory_info_; }; + + /** + optional CreateFence interface, as provider like DML has its own fence + */ + virtual FencePtr CreateFence(const SessionState* /*unused*/) { return nullptr; } + + static bool CalcMemSizeForArray(size_t nmemb, size_t size, size_t* out) noexcept { + return CalcMemSizeForArrayWithAlignment(nmemb, size, 0, out); + } + + /** + * Calculate the memory size for an array. The size is bounds checked using SafeInt. + * \tparam alignment must be power of 2 + * \param nmemb Number of members or elements in the array + * \param size Size of each element + * \param out Total size required after any alignment is applied + * \return true, successful. false, overflow + */ + static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept ORT_MUST_USE_RESULT; + + /** + * https://cwe.mitre.org/data/definitions/190.html + * \param alignment must be power of 2 + * \param nmemb Number of members or elements in the array + * \param size Size of each element + * \param out Total size required after any alignment is applied + * \return true, successful. false, overflow + * \remarks This was the original API and was implemented in the header. Replaced with the above version + * implemented in the .cc file so that the SafeInt dependency is internal. + */ + template + static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept ORT_MUST_USE_RESULT; + + /** + * allocate memory for an array which has nmemb items of data, each size bytes long + */ + void* AllocArray(size_t nmemb, size_t size) { + size_t len; + if (!CalcMemSizeForArray(nmemb, size, &len)) + return nullptr; + return Alloc(len); + } + + /** + * allocate memory for an array which has nmemb items of data, each size bytes long + */ + template + void* AllocArrayWithAlignment(size_t nmemb, size_t size) { + size_t len; + if (!CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, &len)) + return nullptr; + return Alloc(len); + } + + /** + Create a std::unique_ptr that is allocated and freed by the provided IAllocator. + @param allocator The allocator. + @param count_or_bytes The exact bytes to allocate if T is void, otherwise the number of elements to allocate. + @returns std::unique_ptr with allocated memory and deleter. + */ + template + static IAllocatorUniquePtr MakeUniquePtr(std::shared_ptr allocator, size_t count_or_bytes) { + if (allocator == nullptr) return nullptr; + // for now limit to fundamental types. we could support others, but to do so either we or the caller + // needs to call the dtor for the objects, for buffers allocated on device we don't have destructor + //static_assert(std::is_fundamental::value, "Fundamental type required as no destructors are called."); + + size_t alloc_size = count_or_bytes; + + // if T is not void, 'count_or_bytes' == number of items so allow for that + if (!std::is_void::value) { + // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't + // reachable if T is void. use std::conditional to 'use' void* in the sizeof call + if (!CalcMemSizeForArray(count_or_bytes, + sizeof(typename std::conditional::value, void*, T>::type), + &alloc_size)) return nullptr; + } + + return IAllocatorUniquePtr{ + static_cast(allocator->Alloc(alloc_size)), // allocate + [=](T* ptr) { // capture 'allocator' by value so it's always valid + allocator->Free(ptr); + }}; + } + + private: + OrtMemoryInfo memory_info_; +}; + +template +bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept { + return CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, out); +} + +/** + The resource allocator on a physical device. + This allocator will directly allocate resource from system call +*/ +class IDeviceAllocator : public IAllocator { + public: + IDeviceAllocator(const OrtMemoryInfo& info) : IAllocator(info) {} + ~IDeviceAllocator() override = default; + void* Alloc(size_t size) override = 0; + void Free(void* p) override = 0; +}; + +class CPUAllocator : public IDeviceAllocator { + public: + explicit CPUAllocator(const OrtMemoryInfo& memory_info) : IDeviceAllocator(memory_info) {} + + CPUAllocator() : IDeviceAllocator(OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)) {} + + void* Alloc(size_t size) override; + void Free(void* p) override; +}; + +#if defined(USE_MIMALLOC_ARENA_ALLOCATOR) +class MiMallocAllocator : public IDeviceAllocator { + public: + explicit MiMallocAllocator(const OrtMemoryInfo& memory_info) : IDeviceAllocator(memory_info) {} + MiMallocAllocator() : IDeviceAllocator(OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)) {} + + void* Alloc(size_t size) override; + void Free(void* p) override; +}; + +#endif + +#if defined(USE_MIMALLOC_ARENA_ALLOCATOR) +using TAllocator = MiMallocAllocator; +#else +using TAllocator = CPUAllocator; +#endif + +using AllocatorPtr = std::shared_ptr; + +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/customregistry.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/customregistry.h new file mode 100644 index 0000000..aafe5a5 --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/customregistry.h @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/status.h" +#include "core/common/logging/logging.h" +#include "core/graph/schema_registry.h" +#include "core/framework/op_kernel.h" +#include "core/framework/kernel_def_builder.h" +#include "core/framework/kernel_registry.h" + +namespace onnxruntime { + +/** + Represents a registry that contains both custom kernels and custom schemas. +*/ +class CustomRegistry final { + public: + CustomRegistry() : + kernel_registry_(std::make_shared()), + opschema_registry_(std::make_shared()) {} + + /** + * Register a kernel definition together with kernel factory method to this session. + * If any conflict happened between registered kernel def and built-in kernel def, + * registered kernel will have higher priority. + * Call this before invoking Initialize(). + * @return OK if success. + */ + common::Status RegisterCustomKernel(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator); + + common::Status RegisterCustomKernel(KernelCreateInfo&); + + common::Status RegisterOpSet(std::vector& schemas, const std::string& domain, + int baseline_opset_version, int opset_version); + + const std::shared_ptr& GetKernelRegistry(); + + const std::shared_ptr& GetOpschemaRegistry(); + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomRegistry); + std::shared_ptr kernel_registry_; + std::shared_ptr opschema_registry_; + +}; + +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/data_types.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/data_types.h new file mode 100644 index 0000000..fab910f --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/data_types.h @@ -0,0 +1,1013 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/exceptions.h" +#include "core/framework/endian.h" +#include "core/graph/onnx_protobuf.h" + +struct OrtValue; + +namespace ONNX_NAMESPACE { +class TypeProto; +} // namespace ONNX_NAMESPACE + +namespace onnxruntime { +/// Predefined registered types + +//maps +using MapStringToString = std::map; +using MapStringToInt64 = std::map; +using MapStringToFloat = std::map; +using MapStringToDouble = std::map; +using MapInt64ToString = std::map; +using MapInt64ToInt64 = std::map; +using MapInt64ToFloat = std::map; +using MapInt64ToDouble = std::map; + +//vectors/sequences +using VectorMapStringToFloat = std::vector; +using VectorMapInt64ToFloat = std::vector; +using VectorString = std::vector; +using VectorInt64 = std::vector; + +class DataTypeImpl; +class TensorTypeBase; +class SparseTensorTypeBase; +class SequenceTensorTypeBase; +class NonTensorTypeBase; +class PrimitiveDataTypeBase; + +// MLFloat16 +union MLFloat16 { + uint16_t val; + + explicit MLFloat16(uint16_t x) : val(x) {} + MLFloat16() : val(0) {} + + // Taken from https://stackoverflow.com/a/60047308/12627730 + float AsFloat(uint32_t x) const { + float out = 0.0f; + std::memcpy(&out, &x, sizeof(x)); + return out; + } + + // Taken from https://stackoverflow.com/a/60047308/12627730 + uint32_t AsUint(float x) const { + uint32_t out = 0; + std::memcpy(&out, &x, sizeof(x)); + return out; + } + + float HalfToFloat(const uint16_t x) const { + uint16_t half = x; + if (endian::native == endian::big) { + // Taken from https://stackoverflow.com/a/2182184/12627730 + half = (x >> 8) | (x << 8); + } + + // Taken from https://stackoverflow.com/a/60047308/12627730 + // IEEE-754 16-bit floating-point format (without infinity): 1-5-10, exp-15, +-131008.0, +-6.1035156E-5, + // +-5.9604645E-8, 3.311 digits + const uint32_t e = (half & 0x7C00) >> 10; // exponent + const uint32_t m = (half & 0x03FF) << 13; // mantissa + // evil log2 bit hack to count leading zeros in denormalized format + const uint32_t v = AsUint(static_cast(m)) >> 23; + uint32_t full = (half & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) | + ((e == 0) & (m != 0)) * ((v - 37) << 23 | ((m << (150 - v)) & 0x007FE000)); // sign : normalized : denormalized + + if (endian::native == endian::big) { + // Taken from https://stackoverflow.com/a/2182184/12627730 + full = ((full >> 24) & 0xff) | // move byte 3 to byte 0 + ((full << 8) & 0xff0000) | // move byte 1 to byte 2 + ((full >> 8) & 0xff00) | // move byte 2 to byte 1 + ((full << 24) & 0xff000000); // byte 0 to byte 3 + } + + return AsFloat(full); + } + + operator float() const { + return HalfToFloat(val); + } +}; + +inline bool operator==(const MLFloat16& left, const MLFloat16& right) { + return left.val == right.val; +} + +inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { + return left.val != right.val; +} + +inline bool operator<(const MLFloat16& left, const MLFloat16& right) { + return left.val < right.val; +} + +//BFloat16 +struct BFloat16 { + uint16_t val{0}; + explicit BFloat16() = default; + explicit BFloat16(uint16_t v) : val(v) {} + explicit BFloat16(float v) { + if (endian::native == endian::little) { + std::memcpy(&val, reinterpret_cast(&v) + sizeof(uint16_t), sizeof(uint16_t)); + } else { + std::memcpy(&val, &v, sizeof(uint16_t)); + } + } + + float ToFloat() const { + float result; + char* const first = reinterpret_cast(&result); + char* const second = first + sizeof(uint16_t); + if (endian::native == endian::little) { + std::memset(first, 0, sizeof(uint16_t)); + std::memcpy(second, &val, sizeof(uint16_t)); + } else { + std::memcpy(first, &val, sizeof(uint16_t)); + std::memset(second, 0, sizeof(uint16_t)); + } + return result; + } +}; + +inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) { + auto src = blf; + auto d = flt; + for (; size != 0; ++src, ++d, --size) { + *d = src->ToFloat(); + } +} + +inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) { + auto src = flt; + auto d = blf; + for (; size != 0; ++src, ++d, --size) { + new (d) BFloat16(*src); + } +} + +inline bool operator==(const BFloat16& left, const BFloat16& right) { + return left.val == right.val; +} + +inline bool operator!=(const BFloat16& left, const BFloat16& right) { + return left.val != right.val; +} + +inline bool operator<(const BFloat16& left, const BFloat16& right) { + return left.val < right.val; +} + +// DataTypeImpl pointer as unique DataTypeImpl identifier. +using MLDataType = const DataTypeImpl*; +// be used with class MLValue +using DeleteFunc = void (*)(void*); +using CreateFunc = void* (*)(); + +/** + * \brief Base class for MLDataType + * + */ +class DataTypeImpl { + public: + virtual ~DataTypeImpl() = default; + + /** + * \brief this API will be used to check type compatibility at runtime + * + * \param type_proto a TypeProto instance that is constructed for a specific type + * will be checked against a TypeProto instance contained within a corresponding + * MLDataType instance. + */ + virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0; + + virtual size_t Size() const = 0; + + virtual DeleteFunc GetDeleteFunc() const = 0; + + /** + * \brief Retrieves an instance of TypeProto for + * a given MLDataType + * \returns optional TypeProto. Only ONNX types + has type proto, non-ONNX types will return nullptr. + */ + virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0; + + virtual bool IsTensorType() const { + return false; + } + + virtual bool IsTensorSequenceType() const { + return false; + } + + virtual bool IsSparseTensorType() const { + return false; + } + + // Returns this if this is of tensor-type and null otherwise + virtual const TensorTypeBase* AsTensorType() const { + return nullptr; + } + + virtual const SequenceTensorTypeBase* AsSequenceTensorBase() const { + return nullptr; + } + + // Returns this if this is of sparse-tensor-type and null otherwise + virtual const SparseTensorTypeBase* AsSparseTensorType() const { + return nullptr; + } + + virtual const NonTensorTypeBase* AsNonTensorTypeBase() const { + return nullptr; + } + + // Returns this if this is one of the primitive data types (specialization of PrimitiveDataTypeBase) + // and null otherwise + virtual const PrimitiveDataTypeBase* AsPrimitiveDataType() const { + return nullptr; + } + + // Return the type meta that we are using in the runtime. + template + static MLDataType GetType(); + + // Return the types for a concrete tensor type, like Tensor_Float + template + static MLDataType GetTensorType(); + + template + static MLDataType GetSequenceTensorType(); + + // Return the MLDataType for a concrete sparse tensor type. + template + static MLDataType GetSparseTensorType(); + + /** + * Convert an ONNX TypeProto to onnxruntime DataTypeImpl. + * However, this conversion is lossy. Don't try to use 'this->GetTypeProto()' converting it back. + * Even though GetTypeProto() will not have the original information, it will still have enough to correctly + * map to MLDataType. + * \param proto + */ + static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto); + + static const TensorTypeBase* TensorTypeFromONNXEnum(int type); + static const SparseTensorTypeBase* SparseTensorTypeFromONNXEnum(int type); + static const NonTensorTypeBase* SequenceTensorTypeFromONNXEnum(int type); + + static const char* ToString(MLDataType type); + // Registers ONNX_NAMESPACE::DataType (internalized string) with + // MLDataType. DataType is produced by internalizing an instance of + // TypeProto contained within MLDataType + static void RegisterDataType(MLDataType); + static MLDataType GetDataType(const std::string&); + + static const std::vector& AllTensorTypes(); + static const std::vector& AllSequenceTensorTypes(); + static const std::vector& AllFixedSizeTensorTypes(); + static const std::vector& AllNumericTensorTypes(); + static const std::vector& AllIEEEFloatTensorTypes(); + static const std::vector& AllFixedSizeTensorExceptHalfTypes(); + static const std::vector& AllIEEEFloatTensorExceptHalfTypes(); +}; + +std::ostream& operator<<(std::ostream& out, MLDataType data_type); + +/* + * Type registration helpers + */ +namespace data_types_internal { +/// TensorType helpers +/// + +template +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType(); + +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT; +} +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT8; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT8; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT16; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT16; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT32; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT64; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_STRING; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_BOOL; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT32; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT64; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorDataType() { + return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16; +} + +// There is a specialization only for one +// type argument. +template +struct TensorElementTypeSetter { + static void SetTensorElementType(ONNX_NAMESPACE::TypeProto&); + static void SetMapKeyType(ONNX_NAMESPACE::TypeProto&); + static int32_t GetElementType(); +}; + +/// Is a given type on the list of types? +/// Accepts a list of types and the first argument is the type +/// We are checking if it is listed among those that follow +template +struct IsAnyOf; + +/// Two types remaining, end of the list +template +struct IsAnyOf : public std::is_same { +}; + +template +struct IsAnyOf { + static constexpr bool value = (std::is_same::value || + IsAnyOf::value); +}; + +/// Tells if the specified type is one of fundamental types +/// that can be contained within a tensor. +/// We do not have raw fundamental types, rather a subset +/// of fundamental types is contained within tensors. +template +struct IsTensorContainedType : public IsAnyOf { +}; + +/// Use "IsSparseTensorContainedType::value" to test if a type T +/// is permitted as the element-type of a sparse-tensor. + +template +struct IsSparseTensorContainedType : public IsAnyOf { +}; + +/// This template's Get() returns a corresponding MLDataType +/// It dispatches the call to either GetTensorType<>() or +/// GetType<>() +template +struct GetMLDataType; + +template +struct GetMLDataType { + static MLDataType Get() { + return DataTypeImpl::GetTensorType(); + } +}; + +template +struct GetMLDataType { + static MLDataType Get() { + return DataTypeImpl::GetType(); + } +}; + +/// MapTypes helper API +/// K should always be one of the primitive data types +/// V can be either a primitive type (in which case it is a tensor) +/// or other preregistered types + +void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto&, + ONNX_NAMESPACE::TypeProto&); + +template +struct SetMapTypes { + static void Set(ONNX_NAMESPACE::TypeProto& proto) { + TensorElementTypeSetter::SetMapKeyType(proto); + MLDataType dt = GetMLDataType::value>::Get(); + const auto* value_proto = dt->GetTypeProto(); + ORT_ENFORCE(value_proto != nullptr, typeid(V).name(), + " expected to be a registered ONNX type"); + CopyMutableMapValue(*value_proto, proto); + } +}; + +/// Sequence helpers +/// +// Element type is a primitive type so we set it to a tensor +void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&, + ONNX_NAMESPACE::TypeProto&); + +template +struct SetSequenceType { + static void Set(ONNX_NAMESPACE::TypeProto& proto) { + MLDataType dt = GetMLDataType::value>::Get(); + const auto* elem_proto = dt->GetTypeProto(); + ORT_ENFORCE(elem_proto != nullptr, typeid(T).name(), + " expected to be a registered ONNX type"); + CopyMutableSeqElement(*elem_proto, proto); + } +}; + +/// OpaqueTypes helpers +/// +void AssignOpaqueDomainName(const char* domain, const char* name, + ONNX_NAMESPACE::TypeProto& proto); + +} // namespace data_types_internal + +/// All tensors base +class TensorTypeBase : public DataTypeImpl { + public: + static MLDataType Type(); + + /// We first compare type_proto pointers and then + /// if they do not match try to account for the case + /// where TypeProto was created ad-hoc and not queried from MLDataType + bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override; + + bool IsTensorType() const override { + return true; + } + + const TensorTypeBase* AsTensorType() const override { + return this; + } + + size_t Size() const override; + + DeleteFunc GetDeleteFunc() const override; + + const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override; + + virtual MLDataType GetElementType() const { + // should never reach here. + ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); + } + + TensorTypeBase(const TensorTypeBase&) = delete; + TensorTypeBase& operator=(const TensorTypeBase&) = delete; + + protected: + ONNX_NAMESPACE::TypeProto& mutable_type_proto(); + + TensorTypeBase(); + ~TensorTypeBase() override; + + private: + struct Impl; + Impl* impl_; +}; + +/** + * \brief Tensor type. This type does not have a C++ type associated with + * it at registration time except the element type. One of the types mentioned + * above at IsTensorContainedType<> list is acceptable. + * + * \details + * Usage: + * ORT_REGISTER_TENSOR(ELEMENT_TYPE) + * Currently all of the Tensors irrespective of the dimensions are mapped to Tensor + * type. IsCompatible() currently ignores shape. + */ + +template +class TensorType : public TensorTypeBase { + public: + static_assert(data_types_internal::IsTensorContainedType::value, + "Requires one of the tensor fundamental types"); + + static MLDataType Type(); + + /// Tensors only can contain basic data types + /// that have been previously registered with ONNXRuntime + MLDataType GetElementType() const override { + return DataTypeImpl::GetType(); + } + + private: + TensorType() { + using namespace data_types_internal; + TensorElementTypeSetter::SetTensorElementType(this->mutable_type_proto()); + } +}; + +/// Common base-class for all sparse-tensors (with different element types). +class SparseTensorTypeBase : public DataTypeImpl { + public: + static MLDataType Type(); + + bool IsSparseTensorType() const override { + return true; + } + + const SparseTensorTypeBase* AsSparseTensorType() const override { + return this; + } + + bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override; + + size_t Size() const override; + + DeleteFunc GetDeleteFunc() const override; + + const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override; + + virtual MLDataType GetElementType() const { + // should never reach here. + ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); + } + + SparseTensorTypeBase(const SparseTensorTypeBase&) = delete; + SparseTensorTypeBase& operator=(const SparseTensorTypeBase&) = delete; + + protected: + ONNX_NAMESPACE::TypeProto& mutable_type_proto(); + + SparseTensorTypeBase(); + ~SparseTensorTypeBase() override; + + private: + struct Impl; + Impl* impl_; +}; + +template +class SparseTensorType : public SparseTensorTypeBase { + public: + static_assert(data_types_internal::IsSparseTensorContainedType::value, + "Requires one of the sparse-tensor fundamental types"); + + static MLDataType Type(); + + /// Return a MLDataType representing the element-type + MLDataType GetElementType() const override { + return DataTypeImpl::GetType(); + } + + private: + SparseTensorType() { + using namespace data_types_internal; + TensorElementTypeSetter::SetSparseTensorElementType(mutable_type_proto()); + } +}; + +/** + * \brief Provide a specialization for your C++ Non-tensor type + * so your implementation FromDataTypeContainer/ToDataTypeContainer + * functions correctly. Otherwise you get a default implementation + * which may not be what you need/want. + * + * This class is used to create OrtValue, fetch data from OrtValue via + * C/C++ APIs + */ +template +struct NonTensorTypeConverter { + static void FromContainer(MLDataType /*dtype*/, const void* /*data*/, size_t /*data_size*/, OrtValue& /*output*/) { + ORT_THROW("Not implemented"); + } + static void ToContainer(const OrtValue& /*input*/, size_t /*data_size*/, void* /*data*/) { + ORT_THROW("Not implemented"); + } +}; + +/** + * \brief Base type for all non-tensors, maps, sequences and opaques + */ +class NonTensorTypeBase : public DataTypeImpl { + public: + size_t Size() const override = 0; + + DeleteFunc GetDeleteFunc() const override = 0; + + virtual CreateFunc GetCreateFunc() const = 0; + + const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override; + + const NonTensorTypeBase* AsNonTensorTypeBase() const override { + return this; + } + + // \brief Override for Non-tensor types to initialize non-tensor CPP + // data representation from data. The caller of the interface + // should have a shared definition of the data which is used to initialize + // CPP data representation. This is used from C API. + // + // \param data - pointer to a data container structure non_tensor type specific + // \param data_size - size of the data container structure, used for rudimentary checks + // \param output - reference to a default constructed non-tensor type + // \returns OrtValue + // \throw if there is an error + virtual void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const; + + // \brief Override for Non-tensor types to fetch data from the internal CPP data representation + // The caller of the interface should have a shared definition of the data which is used to initialize + // CPP data representation. This is used from C API. + // + // \param input - OrtValue containing data + // \param data_size - size of the structure that is being passed for receiving data, used for + // validation + // \param data - pointer to receiving data structure + virtual void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const; + + NonTensorTypeBase(const NonTensorTypeBase&) = delete; + NonTensorTypeBase& operator=(const NonTensorTypeBase&) = delete; + + protected: + NonTensorTypeBase(); + ~NonTensorTypeBase() override; + + ONNX_NAMESPACE::TypeProto& mutable_type_proto(); + + bool IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const; + + bool IsSequenceCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const; + + bool IsOpaqueCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const; + + private: + struct Impl; + Impl* impl_; +}; + +// This is where T is the actual CPPRuntimeType +template +class NonTensorType : public NonTensorTypeBase { + private: + static void Delete(void* p) { + delete static_cast(p); + } + + public: + size_t Size() const override { + return sizeof(T); + } + + DeleteFunc GetDeleteFunc() const override { + return &Delete; + } + + CreateFunc GetCreateFunc() const override { + return []() -> void* { return new T(); }; + } + + protected: + NonTensorType() = default; +}; + +/** + * \brief MapType. Use this type to register + * mapping types. + * + * \param T - cpp type that you wish to register as runtime MapType + * + * \details Usage: ORT_REGISTER_MAP(C++Type) + * The type is required to have mapped_type and + * key_type defined + */ +template +class MapType : public NonTensorType { + public: + static_assert(data_types_internal::IsTensorContainedType::value, + "Requires one of the tensor fundamental types as key"); + + static MLDataType Type(); + + bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override { + return this->IsMapCompatible(type_proto); + } + + private: + MapType() { + using namespace data_types_internal; + SetMapTypes::Set(this->mutable_type_proto()); + } +}; + +/** + * \brief SequenceType. Use to register sequence for non-tensor types. + * + * \param T - CPP type that you wish to register as Sequence + * runtime type. + * + * \details Usage: ORT_REGISTER_SEQ(C++Type) + * The type is required to have value_type defined + */ +template +class SequenceType : public NonTensorType { + public: + static MLDataType Type(); + + bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override { + return this->IsSequenceCompatible(type_proto); + } + + private: + SequenceType() { + data_types_internal::SetSequenceType::Set(this->mutable_type_proto()); + } +}; + +/** + * \brief SequenceTensorTypeBase serves as a base type class for + * Tensor sequences. Akin TensorTypeBase. + * Runtime representation is always TensorSeq. + */ +class SequenceTensorTypeBase : public DataTypeImpl { + public: + static MLDataType Type(); + + bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override; + + bool IsTensorSequenceType() const override { + return true; + } + + const SequenceTensorTypeBase* AsSequenceTensorBase() const override { + return this; + } + + virtual MLDataType GetElementType() const { + // should never reach here. + ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); + } + + size_t Size() const override; + + DeleteFunc GetDeleteFunc() const override; + + const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override; + + SequenceTensorTypeBase(const SequenceTensorTypeBase&) = delete; + SequenceTensorTypeBase& operator=(const SequenceTensorTypeBase&) = delete; + + protected: + SequenceTensorTypeBase(); + ~SequenceTensorTypeBase(); + + ONNX_NAMESPACE::TypeProto& mutable_type_proto(); + + private: + struct Impl; + Impl* impl_; +}; + +/** + * \brief SequenceTensorType. Use to register sequence for non-tensor types. + * + * \param CPPRuntime - We always use TensorSeq + * + * \param TensorElemType - one of the primitive types + * + * \details Usage: ORT_REGISTER_SEQ_TENSOR_TYPE() + * The type is required to have value_type defined + */ +template +class SequenceTensorType : public SequenceTensorTypeBase { + public: + static MLDataType Type(); + + /// Return a MLDataType representing the element-type + MLDataType GetElementType() const override { + return DataTypeImpl::GetType(); + } + + private: + SequenceTensorType() { + data_types_internal::SetSequenceType::Set(this->mutable_type_proto()); + } +}; + +/** + * \brief OpaqueType + * + * \param T - cpp runtume that implements the Opaque type + * + * \param const char D[] - domain must be extern to be unique + * + * \param const char N[] - name must be extern to be unique + * + * \details Only one CPP type can be associated with a particular + * OpaqueType registration + * + */ +template +class OpaqueType : public NonTensorType { + public: + static MLDataType Type(); + + bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override { + return this->IsOpaqueCompatible(type_proto); + } + + void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const override { + NonTensorTypeConverter::FromContainer(this, data, data_size, output); + } + + void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const override { + NonTensorTypeConverter::ToContainer(input, data_size, data); + } + + private: + OpaqueType() { + data_types_internal::AssignOpaqueDomainName(D, N, this->mutable_type_proto()); + } +}; + +/** + * \brief PrimitiveDataTypeBase + * Base class for primitive Tensor contained types + * + * \details This class contains an integer constant that can be + * used for input data type dispatching + * + */ +class PrimitiveDataTypeBase : public DataTypeImpl { + public: + bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override { + return false; + } + + const PrimitiveDataTypeBase* AsPrimitiveDataType() const override final { + return this; + } + + const ONNX_NAMESPACE::TypeProto* GetTypeProto() const final { + return nullptr; + } + + int32_t GetDataType() const { + return data_type_; + } + + protected: + PrimitiveDataTypeBase() = default; + + void SetDataType(int32_t data_type) { + data_type_ = data_type; + } + + private: + int32_t data_type_; +}; + +/** + * \brief PrimitiveDataType + * Typed specialization for primitive types. + * Concrete instances of this class are used by Tensor. + * + * \param T - primitive data type + * + */ +template +class PrimitiveDataType : public PrimitiveDataTypeBase { + private: + static void Delete(void* p) { + delete static_cast(p); + } + + public: + static MLDataType Type(); + + size_t Size() const override { + return sizeof(T); + } + + DeleteFunc GetDeleteFunc() const override { + return &Delete; + } + + private: + PrimitiveDataType() { + this->SetDataType(data_types_internal::TensorElementTypeSetter::GetElementType()); + } +}; + +// Explicit specialization of base class template function +// is only possible within the enclosing namespace scope, +// thus a simple way to pre-instantiate a given template +// at a registration time does not currently work and the macro +// is needed. +#define ORT_REGISTER_TENSOR_TYPE(ELEM_TYPE) \ + template <> \ + MLDataType TensorType::Type() { \ + static TensorType tensor_type; \ + return &tensor_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetTensorType() { \ + return TensorType::Type(); \ + } + +#define ORT_REGISTER_SPARSE_TENSOR_TYPE(ELEM_TYPE) \ + template <> \ + MLDataType SparseTensorType::Type() { \ + static SparseTensorType tensor_type; \ + return &tensor_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetSparseTensorType() { \ + return SparseTensorType::Type(); \ + } + +#define ORT_REGISTER_MAP(TYPE) \ + template <> \ + MLDataType MapType::Type() { \ + static MapType map_type; \ + return &map_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetType() { \ + return MapType::Type(); \ + } + +#define ORT_REGISTER_SEQ(TYPE) \ + template <> \ + MLDataType SequenceType::Type() { \ + static SequenceType sequence_type; \ + return &sequence_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetType() { \ + return SequenceType::Type(); \ + } + +#define ORT_REGISTER_SEQ_TENSOR_TYPE(ELEM_TYPE) \ + template <> \ + MLDataType SequenceTensorType::Type() { \ + static SequenceTensorType sequence_tensor_type; \ + return &sequence_tensor_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetSequenceTensorType() { \ + return SequenceTensorType::Type(); \ + } + +#define ORT_REGISTER_PRIM_TYPE(TYPE) \ + template <> \ + MLDataType PrimitiveDataType::Type() { \ + static PrimitiveDataType prim_data_type; \ + return &prim_data_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetType() { \ + return PrimitiveDataType::Type(); \ + } + +#define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \ + template <> \ + MLDataType OpaqueType::Type() { \ + static OpaqueType opaque_type; \ + return &opaque_type; \ + } \ + template <> \ + MLDataType DataTypeImpl::GetType() { \ + return OpaqueType::Type(); \ + } +} // namespace onnxruntime diff --git a/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/data_types_internal.h b/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/data_types_internal.h new file mode 100644 index 0000000..07534bd --- /dev/null +++ b/onnxruntime/src/main/cpp/includes/onnxruntime/core/framework/data_types_internal.h @@ -0,0 +1,501 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/data_types.h" +#include "core/graph/onnx_protobuf.h" + +#ifdef _MSC_VER +#pragma warning(push) +//TODO: fix the warning in CallableDispatchableRetHelper +#pragma warning(disable : 4702) +#endif +namespace onnxruntime { +namespace utils { + +template +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; +} + +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT; +} +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT8; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT8; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT16; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT16; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT32; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT64; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_STRING; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_BOOL; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT32; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT64; +}; +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16; +}; + + // The following primitives are strongly recommended for switching on tensor input datatypes for + // kernel implementations. + // + // 1) If you need to handle all of the primitive tensor contained datatypes, the best choice would be macros + // DispatchOnTensorType or DispatchOnTensorTypeWithReturn. Use inline wrappers so your function can be invoked as function(). + // 2) if you have a few types, use Tensor.IsDataType()/IsDataTypeString() or use utils::IsPrimitiveDataType() + // if you have a standalone MLDatatType with a sequence of if/else statements. + // 3) For something in between, we suggest to use CallDispatcher pattern. + // + // Invoking DataTypeImpl::GetType() for switching on input types is discouraged and should be avoided. + // Every primitive type carries with it an integer constant that can be used for quick switching on types. + +#define DispatchOnTensorType(tensor_type, function, ...) \ + switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_STRING: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT8: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT32: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT64: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \ + function(__VA_ARGS__); \ + break; \ + default: \ + ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ + } + +#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \ + switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_STRING: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT8: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT32: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT64: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \ + retval = function(__VA_ARGS__); \ + break; \ + default: \ + ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ + } + +//////////////////////////////////////////////////////////////////////////////// +/// Use the following primitives if you have a few types to switch on so you +// can write a short sequence of if/else statements. + +// This is a frequently used check so we make a separate utility function. +inline bool IsDataTypeString(MLDataType dt_type) { + auto prim_type = dt_type->AsPrimitiveDataType(); + return (prim_type != nullptr && prim_type->GetDataType() == ONNX_NAMESPACE::TensorProto_DataType_STRING); +} + +// Test if MLDataType is a concrete type of PrimitiveDataTypeBase +// and it is T +template +inline bool IsPrimitiveDataType(MLDataType dt_type) { + auto prim_type = dt_type->AsPrimitiveDataType(); + return (prim_type != nullptr && prim_type->GetDataType() == ToTensorProtoElementType()); +} + +// Use after AsPrimitiveDataType() is successful +// Check if PrimitiveDataTypeBase is of type T +template +inline bool IsPrimitiveDataType(const PrimitiveDataTypeBase* prim_type) { + assert(prim_type != nullptr); + return prim_type->GetDataType() == ToTensorProtoElementType(); +} + +// This implementation contains a workaround for GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=47226 +// GCC until very recently does not support template parameter pack expansion within lambda context. +namespace mltype_dispatcher_internal { +// T - type handled by this helper +struct CallableDispatchableHelper { + int32_t dt_type_; // Type currently dispatched + size_t called_; + + explicit CallableDispatchableHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0) {} + + // Must return integer to be in a expandable context + template + int Invoke(Fn&& fn, Args&&... args) { + if (utils::ToTensorProtoElementType() == dt_type_) { + std::forward(fn)(std::forward(args)...); + ++called_; + } + return 0; + } +}; + +// Default policy is to throw with no return type. +template +struct UnsupportedTypeDefaultPolicy { + Ret operator()(int32_t dt_type) const { + ORT_THROW("Unsupported data type: ", dt_type); + } +}; + +// Helper with the result type +template > +struct CallableDispatchableRetHelper { + int32_t dt_type_; // Type currently dispatched + size_t called_; + Ret result_; + + explicit CallableDispatchableRetHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0), result_() {} + + Ret Get() { + // See if there were multiple invocations.It is a bug. + ORT_ENFORCE(called_ < 2, "Check for duplicate types in MLTypeCallDispatcherRet"); + // No type was invoked + if (called_ == 0) { + result_ = UnsupportedPolicy()(dt_type_); + } + return result_; + } + + // Must return integer to be in a expandable context + template + int Invoke(Fn&& fn, Args&&... args) { + if (utils::ToTensorProtoElementType() == dt_type_) { + result_ = std::forward(fn)(std::forward(args)...); + ++called_; + } + return 0; + } +}; + +} // namespace mltype_dispatcher_internal + +// This class helps to efficiently dispatch calls for templated +// kernel implementation functions that has no return value. +// If your implementation function must return a value such as Status +// Use MLTypeCallDispatcherRet class. +// +// The first template parameter is a template struct/class functor +// that must implement operator() with arbitrary number of arguments +// and void return turn. It must return Ret type if you are using MLTypeCallDispatcherRet. +// Fn must be default constructible. +// +// Types is a type list that are supported by this kernel implementation. +// There should be no duplicate types. An exception will be thrown if there +// a duplicate. +// +// The constructor accepts an enum that is obtained from +// input_tensor->DataType()->AsPrimitiveType()->GetDataType(). +// Fn will be called only once the type designated by dt_type value. +// If current dt_type is not handled, the Dispatcher will throw an exception. +// +template